Remove the mul_mat_id parameters (off1, cne1) from ggml_compute_forward_mul_mat

This commit is contained in:
slaren 2023-12-15 10:49:16 +01:00
parent 66ce753abd
commit 3de63cf103

27
ggml.c
View file

@ -9580,16 +9580,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
}
#endif
// off1 = offset in i11 and i1
// cne1 = ne11 and ne1
// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst,
int64_t off1, int64_t cne1) {
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
@ -9657,9 +9652,9 @@ static void ggml_compute_forward_mul_mat(
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata;
@ -9676,7 +9671,7 @@ static void ggml_compute_forward_mul_mat(
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
cne1, ne01, ne10,
ne1, ne01, ne10,
1.0f, y, ne10,
x, ne00,
0.0f, d, ne01);
@ -9717,8 +9712,8 @@ static void ggml_compute_forward_mul_mat(
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = ne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
@ -9760,9 +9755,9 @@ static void ggml_compute_forward_mul_mat(
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*cne1));
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
const int64_t i13 = (ir1/(ne12*ne1));
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
// broadcast src0 into src1
const int64_t i03 = i13/r3;
@ -14344,7 +14339,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_MUL_MAT_ID:
{