Add "add_f16_f32_f32_cuda"
This commit is contained in:
parent
9587ab4c73
commit
86ceda4275
1 changed files with 17 additions and 1 deletions
18
ggml-cuda.cu
18
ggml-cuda.cu
|
@ -496,6 +496,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
|
|||
dst[i] = __hadd(x[i], __float2half(y[i]));
|
||||
}
|
||||
|
||||
static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = __half2float(x[i]) + y[i];
|
||||
}
|
||||
|
||||
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
|
@ -4616,6 +4625,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
|
|||
add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||
}
|
||||
|
||||
static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
||||
add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||
}
|
||||
|
||||
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
||||
const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
|
||||
mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
||||
|
@ -5909,8 +5923,10 @@ inline void ggml_cuda_op_add(
|
|||
add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%d, %d\n", src0->type, dst->type);
|
||||
fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue