cuda : fix binbcast

This commit is contained in:
slaren 2024-04-17 19:28:21 +02:00
parent 997a9b5bd2
commit f7fe79a31d

View file

@ -22,7 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
int ne0, int ne1, int ne2, int ne3, int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13, int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3, /*int s0, */ int s1, int s2, int s3,
/*int s01,*/ int s01, int s02, int s03, /*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) { /*int s10,*/ int s11, int s12, int s13) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x; const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@ -56,7 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
int ne0, int ne1, int ne2, int ne3, int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13, int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3, /*int s0, */ int s1, int s2, int s3,
/*int s01,*/ int s01, int s02, int s03, /*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) { /*int s10,*/ int s11, int s12, int s13) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -103,10 +103,14 @@ struct bin_bcast_cuda {
int nr[4] = { nr0, nr1, nr2, nr3 }; int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension // collapse dimensions until first broadcast dimension
int64_t cne0[] = {ne0, ne1, ne2, ne3}; int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13}; int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb0[] = {nb0, nb1, nb2, nb3};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13}; size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) { auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1]; cne[0] *= cne[1];
cne[1] = cne[2]; cne[1] = cne[2];
@ -126,8 +130,10 @@ struct bin_bcast_cuda {
break; break;
} }
if (i > 0) { if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0); collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1); collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0); collapse(cne0);
collapse(cne1); collapse(cne1);
} }
@ -135,20 +141,30 @@ struct bin_bcast_cuda {
} }
{ {
int64_t ne0 = cne0[0]; int64_t ne0 = cne[0];
int64_t ne1 = cne0[1]; int64_t ne1 = cne[1];
int64_t ne2 = cne0[2]; int64_t ne2 = cne[2];
int64_t ne3 = cne0[3]; int64_t ne3 = cne[3];
int64_t ne00 = cne0[0];
int64_t ne01 = cne0[1];
int64_t ne02 = cne0[2];
int64_t ne03 = cne0[3];
int64_t ne10 = cne1[0]; int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1]; int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2]; int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3]; int64_t ne13 = cne1[3];
size_t nb0 = cnb0[0]; size_t nb0 = cnb[0];
size_t nb1 = cnb0[1]; size_t nb1 = cnb[1];
size_t nb2 = cnb0[2]; size_t nb2 = cnb[2];
size_t nb3 = cnb0[3]; size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0]; size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1]; size_t nb11 = cnb1[1];