ggml_cuda_cpy for f32 -> f32

This commit is contained in:
JohannesGaessler 2023-06-12 17:33:48 +02:00
parent cf5ae8635a
commit 3b6a2ee414
2 changed files with 48 additions and 9 deletions

View file

@ -50,6 +50,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_cuda_op_t)(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
@ -838,6 +839,21 @@ static __global__ void mul_mat_vec_nc_f16_f32(
}
}
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
const float * xi = (float *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
const float * xi = (float *) cxi;
half * dsti = (half *) cdsti;
*dsti = __float2half(*xi);
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
@ -857,10 +873,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int i10 = i - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
const float * xi = (float *) (cx + x_offset);
half * dsti = (half *) (cdst + dst_offset);
*dsti = __float2half(*xi);
cpy_1(cx + x_offset, cdst + dst_offset);
}
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
@ -1146,13 +1159,23 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
(vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
}
static void ggml_cpy_f32_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f32_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
@ -2264,9 +2287,15 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, cudaStream_main);
// test<<<ggml_nelements(src0), 1, 0, cudaStream_main>>>(src0_ddf, src1_ddv);
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, cudaStream_main);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, cudaStream_main);
} else {
GGML_ASSERT(false);
}
CUDA_CHECK(cudaDeviceSynchronize());
@ -2362,12 +2391,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
}
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
if (tensor->src0 != nullptr) {
if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src0->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
}
}
if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
}
tensor->backend = GGML_BACKEND_GPU;
struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
@ -2385,6 +2417,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
}
extra->data_device[g_main_device] = src0_ddc + offset;
} else if (tensor->op == GGML_OP_CPY) {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
void * src1_ddv = src1_extra->data_device[g_main_device];
extra->data_device[g_main_device] = src1_ddv;
} else if (scratch) {
GGML_ASSERT(size <= g_scratch_size);
if (g_scratch_offset + size > g_scratch_size) {

View file

@ -1467,6 +1467,7 @@ static bool llama_eval_internal(
#if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func(KQV);
ggml_set_name(KQV, "KQV");
#else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
@ -1478,12 +1479,14 @@ static bool llama_eval_internal(
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
offload_func(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
// projection (no bias)