cuda : implement soft_max_ext
This commit is contained in:
parent
e89597c062
commit
88519fbf97
3 changed files with 28 additions and 14 deletions
35
ggml-cuda.cu
35
ggml-cuda.cu
|
@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||
|
||||
// the CUDA soft max implementation differs from the CPU implementation
|
||||
// floats are used instead of doubles
|
||||
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
|
||||
const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
const int block_size = blockDim.y;
|
||||
const int tid = threadIdx.y;
|
||||
|
||||
float max_val = -INFINITY;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const int i = row*ncols + col;
|
||||
max_val = max(max_val, x[i]);
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
|
|||
float tmp = 0.f;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const int i = row*ncols + col;
|
||||
const float val = expf(x[i] - max_val);
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
|
||||
tmp += val;
|
||||
dst[i] = val;
|
||||
dst[ix] = val;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
|
@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
|
|||
const float inv_tmp = 1.f / tmp;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const int i = row*ncols + col;
|
||||
const int i = rowx*ncols + col;
|
||||
dst[i] *= inv_tmp;
|
||||
}
|
||||
}
|
||||
|
@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||
}
|
||||
|
||||
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
const dim3 block_dims(1, WARP_SIZE, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
}
|
||||
|
||||
static void im2col_f32_f16_cuda(const float * x, half * dst,
|
||||
|
@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
|
|||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
|
||||
|
||||
soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
|
||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_scale(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue