Refactored ggml_cuda_cpy

Author: JohannesGaessler
Date:   2023-06-12 09:19:11 +02:00
Commit: 9a85d913ee (parent 19c0bf5c86)

```diff
@@ -792,21 +792,29 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
     }
 }
 
-static __global__ void cpy_f32_f16(const float * x, void * vdst, const int ne0, const int ne1, const int stride_1, const int stride_2) {
-    const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i0 >= ne0) {
+    if (i >= ne) {
         return;
     }
 
-    const int i1 = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i2 = blockDim.z*blockIdx.z + threadIdx.z;
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = i - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
 
-    half * dst = (half *) vdst;
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = i - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
 
-    const int ix = i0 + i1*stride_1 + i2*stride_2;
-    const int idst = i0 + i1*ne0 + i2*ne1*ne0;
-    dst[idst] = __float2half(x[ix]);
+    const float * xi = (float *) (cx + x_offset);
+    half * dsti = (half *) (cdst + dst_offset);
+
+    *dsti = __float2half(*xi);
 }
 
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
```
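The refactored kernel walks a single flattened element index and resolves the source and destination addresses through per-dimension byte strides (ggml's `nb` convention), de-linearizing `i` independently for each tensor. That is what lets the two sides have completely different layouts, a property the old version lacked for the destination (`idst` assumed a contiguous layout). The following standalone sketch is not part of the commit; the kernel name and the trimming to two dimensions are mine. It exercises the same addressing scheme by copying a transposed f32 view into a contiguous f16 buffer:

```cpp
// Standalone sketch (not from the commit): the byte-stride addressing of
// cpy_f32_f16 above, reduced to two dimensions. Build with: nvcc cpy.cu -o cpy
#include <cuda_fp16.h>
#include <cstdio>

static __global__ void cpy_strided_f32_f16(const char * cx, char * cdst, const int ne,
                                           const int ne00, const int nb00, const int nb01,
                                           const int ne10, const int nb10, const int nb11) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    // de-linearize the flat element index separately for source and destination
    const int i01 = i / ne00;
    const int i00 = i - i01*ne00;
    const int i11 = i / ne10;
    const int i10 = i - i11*ne10;

    const float * xi   = (const float *) (cx   + i00*nb00 + i01*nb01);
    half        * dsti = (half        *) (cdst + i10*nb10 + i11*nb11);

    *dsti = __float2half(*xi);
}

int main() {
    const float h_x[6] = {0, 1, 2, 3, 4, 5}; // 2x3 matrix, row-major
    char * d_x;
    char * d_dst;
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_dst, 6*sizeof(half));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

    // read the source through a transposed (3x2) view: 3 floats apart along
    // dim 0, 1 float apart along dim 1; write a contiguous f16 destination
    cpy_strided_f32_f16<<<1, 32>>>(d_x, d_dst, 6,
                                   2, 3*sizeof(float), sizeof(float),
                                   2, sizeof(half),    2*sizeof(half));

    half h_dst[6];
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 6; ++i) {
        printf("%g ", __half2float(h_dst[i])); // prints: 0 3 1 4 2 5
    }
    printf("\n");

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}
```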
```diff
@@ -1083,12 +1091,14 @@ static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, flo
     mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y);
 }
 
-static void ggml_cpy_f32_f16_cuda(const float * x, void * vdst, const int ne0, const int ne1, const int ne2,
-                                  const int stride_1, const int stride_2, cudaStream_t stream) {
-    const int block_num_x = (ne0 + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, ne1, ne2);
-    const dim3 block_dims(CUDA_CPY_BLOCK_SIZE, 1, 1);
-    cpy_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, vdst, ne0, ne1, stride_1, stride_2);
+static void ggml_cpy_f32_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
```
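The launcher collapses the old 3-D grid into a 1-D grid sized by ceil division, and the kernel's early-out (`if (i >= ne) return;`) absorbs the surplus threads of the last block. A minimal host-side check of that invariant; the constant's value here is an assumption for the sketch, not taken from this commit:

```cpp
#include <cassert>

// placeholder value for this sketch; the real CUDA_CPY_BLOCK_SIZE is a macro
// defined elsewhere in ggml-cuda.cu
constexpr int CUDA_CPY_BLOCK_SIZE = 32;

// ceil-division idiom used by ggml_cpy_f32_f16_cuda above
constexpr int ceil_div(int a, int b) {
    return (a + b - 1) / b;
}

int main() {
    for (int ne = 1; ne <= 1 << 20; ++ne) {
        const int num_blocks = ceil_div(ne, CUDA_CPY_BLOCK_SIZE);
        // every element gets a thread
        assert(num_blocks*CUDA_CPY_BLOCK_SIZE >= ne);
        // the surplus is less than one block, so only the last block idles
        assert(num_blocks*CUDA_CPY_BLOCK_SIZE - ne < CUDA_CPY_BLOCK_SIZE);
    }
    return 0;
}
```

A flat launch also drops the old implicit constraint that `ne1` and `ne2` fit into the y and z grid dimensions, which are limited to 65535 blocks each.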
```diff
@@ -2068,35 +2078,43 @@ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_te
 }
 
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
     GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
 
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == src1->ne[3]);
+    GGML_ASSERT(src0->ne[3] == 1);
 
+    const int64_t nb00 = src0->nb[0];
     const int64_t nb01 = src0->nb[1];
    const int64_t nb02 = src0->nb[2];
 
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
 
-    float * src0_ddf = (float *) src0_extra->data_device[g_main_device];
-    void * src1_ddv = src1_extra->data_device[g_main_device];
+    char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+    char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
-    const int64_t stride_1 = nb01 / sizeof(float);
-    GGML_ASSERT(nb01 % sizeof(float) == 0);
-    const int64_t stride_2 = nb02 / sizeof(float);
-    GGML_ASSERT(nb02 % sizeof(float) == 0);
-
-    ggml_cpy_f32_f16_cuda(src0_ddf, src1_ddv, ne00, ne01, ne02, stride_1, stride_2, cudaStream_main);
+    ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                          ne10, ne11, nb10, nb11, nb12, cudaStream_main);
 
     // test<<<ggml_nelements(src0), 1, 0, cudaStream_main>>>(src0_ddf, src1_ddv);
     CUDA_CHECK(cudaDeviceSynchronize());
```
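On the host side, the refactor stops pre-dividing strides by `sizeof(float)` (and drops the accompanying divisibility asserts) and instead forwards the tensors' raw byte strides to the kernel. The sketch below is not from the commit; the shape is hypothetical. It shows the ggml stride convention this relies on for a non-quantized type: `nb[0]` is the element size and each further `nb[i]` is `nb[i-1]*ne[i-1]` when the tensor is contiguous, while a view keeps its parent's `nb`, so the same offset formula addresses both:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical f32 tensor shape, in ggml order (ne[0] is the fastest dimension)
    const int64_t ne[4] = {4096, 32, 8, 1};

    // contiguous f32 layout: nb[0] = element size, nb[i] = nb[i-1]*ne[i-1]
    int64_t nb[4];
    nb[0] = sizeof(float);
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i-1]*ne[i-1];
    }

    // byte offset of element (i0, i1, i2): the same formula cpy_f32_f16
    // evaluates per thread for both the source and the destination
    const int64_t i0 = 7, i1 = 3, i2 = 5;
    const int64_t offset = i0*nb[0] + i1*nb[1] + i2*nb[2];
    printf("offset = %lld bytes\n", (long long) offset);

    return 0;
}
```

The new `GGML_ASSERT(ggml_nbytes(...) <= INT_MAX)` checks exist because the kernel computes these offsets in `int`: a tensor larger than 2^31 - 1 bytes would overflow them.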