CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
This commit is contained in:
parent
1af6945eb0
commit
8137b4bb2b
7 changed files with 156 additions and 61 deletions
|
@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
|
|||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
|
@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
|
|||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue