implement ggml_compute_forward_out_prod_q_f32

This commit is contained in:
xaedes 2023-08-16 22:00:37 +02:00
parent 79ad888768
commit 83cb9ed4f5
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

136
ggml.c
View file

@ -10623,8 +10623,8 @@ static void ggml_compute_forward_out_prod_f32(
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us(); // int64_t t0 = ggml_perf_time_us();
UNUSED(t0); // UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS; GGML_TENSOR_BINARY_OP_LOCALS;
@ -10725,6 +10725,116 @@ static void ggml_compute_forward_out_prod_f32(
//} //}
} }
static void ggml_compute_forward_out_prod_q_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
// int64_t t0 = ggml_perf_time_us();
// UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS;
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 dim0
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
// dst dim0 cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
// GGML_ASSERT(nb0 <= nb1);
// GGML_ASSERT(nb1 <= nb2);
// GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (params->type == GGML_TASK_INIT) {
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
// parallelize by last three dimensions
// total rows in dst
const int64_t nr = ne1*ne2*ne3;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
// dst[:,:,:,:] = 0
// for i2,i3:
// for i1:
// for i01:
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t ir = ir0; ir < ir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
dequantize_row_q(s0, wdata, ne0);
ggml_vec_mad_f32(ne0, d, wdata, *s1);
}
}
//int64_t t1 = ggml_perf_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
// printf("\n");
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
//}
}
static void ggml_compute_forward_out_prod( static void ggml_compute_forward_out_prod(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
@ -10736,10 +10846,13 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
{ {
GGML_ASSERT(false); // todo ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
// ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
@ -16216,7 +16329,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
n_tasks = n_threads; n_tasks = n_threads;
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{ {
n_tasks = n_threads; n_tasks = n_threads;
@ -16258,6 +16370,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
cur = 0; cur = 0;
} }
work_size = MAX(work_size, cur);
} break;
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->src[0]->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur); work_size = MAX(work_size, cur);
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE: