ggml_cuda_cpy for f32 -> f32

parent cf5ae8635a
commit 3b6a2ee414

2 changed files with 48 additions and 9 deletions:
  54  ggml-cuda.cu
   3  llama.cpp
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -50,6 +50,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
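The added cpy_kernel_t typedef exists so that the copy kernel below can take its per-element copy function as a non-type template parameter: each instantiation bakes one __device__ helper in at compile time, with no runtime branch on the destination type. A minimal standalone sketch of the pattern, with illustrative names that are not from the patch:

typedef void (*elem_cpy_t)(const char * src, char * dst);

__device__ void elem_cpy_f32(const char * src, char * dst) {
    *(float *) dst = *(const float *) src;
}

template <elem_cpy_t elem_cpy>
__global__ void generic_cpy(const char * src, char * dst, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < n) {
        // elem_cpy is a compile-time constant here, so it can be inlined
        elem_cpy(src + i*sizeof(float), dst + i*sizeof(float));
    }
}
// instantiation: generic_cpy<elem_cpy_f32><<<blocks, threads>>>(src, dst, n);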
@@ -838,6 +839,21 @@ static __global__ void mul_mat_vec_nc_f16_f32(
     }
 }
 
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = __float2half(*xi);
+}
+
+template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
     const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
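cpy_1_f32_f16 converts with the __float2half intrinsic from cuda_fp16.h, which rounds to the nearest representable half. A self-contained round-trip check of that conversion, as a sketch separate from the patch (compile with nvcc):

#include <cstdio>
#include <cuda_fp16.h>

__global__ void f32_to_f16(const float * x, half * y) {
    *y = __float2half(*x);
}

int main() {
    float x = 0.1f;
    half  y;
    float * dx; half * dy;
    cudaMalloc(&dx, sizeof(float));
    cudaMalloc(&dy, sizeof(half));
    cudaMemcpy(dx, &x, sizeof(float), cudaMemcpyHostToDevice);
    f32_to_f16<<<1, 1>>>(dx, dy);
    cudaMemcpy(&y, dy, sizeof(half), cudaMemcpyDeviceToHost);
    printf("%.9f -> %.9f\n", x, __half2float(y)); // 0.1 is not exact in f16
    cudaFree(dx);
    cudaFree(dy);
    return 0;
}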
@@ -857,10 +873,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     const int i10 = i - i12*ne10*ne11 - i11*ne10;
     const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
 
-    const float * xi = (float *) (cx + x_offset);
-    half * dsti = (half *) (cdst + dst_offset);
-
-    *dsti = __float2half(*xi);
+    cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
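The kernel addresses source and destination through separate element counts (ne1x) and byte strides (nb1x), which is what lets ggml_cpy handle non-contiguous operands such as the permuted KQV_merged further down. A worked example of the destination-side arithmetic (plain C with illustrative values; the source side, i00/i01/i02 and x_offset computed from ne00/ne01 and nb00..nb02, is symmetric and sits in the part of the kernel this hunk does not show):

#include <assert.h>

int main(void) {
    // planes of 3 rows x 4 columns of f32, stored contiguously
    const int ne10 = 4, ne11 = 3;             // elements per row, rows per plane
    const int nb10 = 4, nb11 = 16, nb12 = 48; // byte strides: element, row, plane

    const int i   = 17;                            // flat element index
    const int i12 = i / (ne10*ne11);               // plane  = 1
    const int i11 = (i - i12*ne10*ne11) / ne10;    // row    = 1
    const int i10 = i - i12*ne10*ne11 - i11*ne10;  // column = 1

    assert(i10*nb10 + i11*nb11 + i12*nb12 == 17*4); // dst_offset = 68 bytes
    return 0;
}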
@@ -1146,13 +1159,23 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
         (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
 }
 
+static void ggml_cpy_f32_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
     const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
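Both launchers round ne up to a whole number of blocks with the usual integer ceil-division idiom; surplus threads in the last block are caught by the i >= ne early return inside the kernel. A quick sanity check of that arithmetic (plain C; the real CUDA_CPY_BLOCK_SIZE is defined elsewhere in ggml-cuda.cu, 32 here is only an assumed example value):

#include <assert.h>

int main(void) {
    const int bs = 32;                // stand-in for CUDA_CPY_BLOCK_SIZE
    assert((  1 + bs - 1) / bs == 1); // a partial block still launches
    assert(( 32 + bs - 1) / bs == 1); // exact fit
    assert(( 33 + bs - 1) / bs == 2); // one extra element adds a block
    return 0;
}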
@@ -2264,9 +2287,15 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
-    ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                          ne10, ne11, nb10, nb11, nb12, cudaStream_main);
-    // test<<<ggml_nelements(src0), 1, 0, cudaStream_main>>>(src0_ddf, src1_ddv);
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
 
     CUDA_CHECK(cudaDeviceSynchronize());
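At the graph level, the new first branch is what an f32-to-f32 ggml_cpy node now lowers to; only the two listed type pairs are supported, and any other combination now fails loudly via GGML_ASSERT instead of silently running the f16 kernel as before. A sketch of graph code reaching each branch (ctx0, src, ne0 and ne1 are assumed to exist, with src an f32 tensor already offloaded to the GPU):

// f32 -> f32: taken by the first branch (the new ggml_cpy_f32_f32_cuda)
struct ggml_tensor * dst32 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ne0, ne1);
struct ggml_tensor * cpy32 = ggml_cpy(ctx0, src, dst32);

// f32 -> f16: taken by the second branch (the pre-existing f16 path)
struct ggml_tensor * dst16 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, ne0, ne1);
struct ggml_tensor * cpy16 = ggml_cpy(ctx0, src, dst16);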
@@ -2362,12 +2391,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 }
 
 void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
-    if (tensor->src0 != nullptr) {
+    if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
             ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
         }
     }
+    if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+    }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
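The added backend check keeps the recursion from re-assigning a src0 that already lives on the GPU, and the new GGML_OP_CPY case pulls the copy destination onto the GPU as well. The src0 recursion only follows chains of no-data ops; a sketch of when it fires (ctx0, x, ne0 and ne1 are illustrative, with x assumed to be a tensor already assigned to the GPU):

struct ggml_tensor * r = ggml_reshape_2d(ctx0, x, ne0, ne1); // no data of its own
struct ggml_tensor * t = ggml_transpose(ctx0, r);            // ditto
ggml_cuda_assign_buffers(t); // recurses into r, since r->op == GGML_OP_RESHAPE
                             // and r->backend is still GGML_BACKEND_CPU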
@@ -2385,6 +2417,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
             memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
         }
         extra->data_device[g_main_device] = src0_ddc + offset;
+    } else if (tensor->op == GGML_OP_CPY) {
+        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+        void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra->data_device[g_main_device] = src1_ddv;
     } else if (scratch) {
         GGML_ASSERT(size <= g_scratch_size);
         if (g_scratch_offset + size > g_scratch_size) {
--- a/llama.cpp
+++ b/llama.cpp
@@ -1467,6 +1467,7 @@ static bool llama_eval_internal(
 
 #if 1
         struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        offload_func(KQV);
         ggml_set_name(KQV, "KQV");
 #else
         // make V contiguous in memory to speed up the matmul, however we waste time on the copy
@@ -1478,12 +1479,14 @@ static bool llama_eval_internal(
 
         // KQV_merged = KQV.permute(0, 2, 1, 3)
         struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+        offload_func(KQV_merged);
         ggml_set_name(KQV_merged, "KQV_merged");
 
         // cur = KQV_merged.contiguous().view(n_embd, N)
         cur = ggml_cpy(ctx0,
                 KQV_merged,
                 ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        offload_func(cur);
         ggml_set_name(cur, "KQV_merged_contiguous");
 
         // projection (no bias)
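These three offload_func calls are what make the f32 -> f32 device copy necessary: cur is a ggml_cpy of the permuted, non-contiguous f32 KQV_merged into a fresh f32 tensor, so once cur is offloaded, ggml_cuda_cpy sees exactly the GGML_TYPE_F32/GGML_TYPE_F32 case added above. offload_func itself is defined outside this diff; a plausible shape, assuming the usual pattern in this offloading series (an assumption, not the verbatim definition):

// hypothetical definition of offload_func, not part of this commit
#ifdef GGML_USE_CUBLAS
#define offload_func(x) ggml_cuda_assign_buffers(x)
#else
#define offload_func(x) do { (void)(x); } while (0)
#endif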
|
Loading…
Add table
Add a link
Reference in a new issue