ggml_cuda_diag_mask_inf
parent 6b46870fea
commit 8d648a34d8
2 changed files with 138 additions and 1 deletion
ggml-cuda.cu (137 changed lines)
@@ -1,5 +1,7 @@
#include <climits>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <atomic>

@@ -154,6 +156,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
@ -827,6 +830,58 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
|
|||
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int n_past) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i = row*ncols + col;
|
||||
// dst[i] = col > n_past + row ? -INFINITY : x[i];
|
||||
dst[i] = x[i] - (col > n_past + row) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
|
||||
}
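For reference, a CPU sketch of what this kernel computes (illustration only, not part of the patch). It matches the commented-out ternary above; subtracting INT_MAX instead of writing -INFINITY drives masked entries to roughly -2.1e9, which the softmax that follows maps to zero weight anyway:

#include <cmath> // for INFINITY

// Reference masking: keep x[i] where col <= n_past + row, otherwise set the
// entry to -inf so that the subsequent softmax assigns it zero probability.
static void diag_mask_inf_f32_ref(const float * x, float * dst, int ncols, int nrows, int n_past) {
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const int i = row*ncols + col;
            dst[i] = col > n_past + row ? -INFINITY : x[i];
        }
    }
}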

static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
    const int block_size = blockDim.x;
    const int tid = threadIdx.x;

    float tmp = 0.0;

    for (int block_start = 0; block_start < ncols; block_start += block_size) {
        const int col = block_start + tid;

        if (col >= ncols) {
            break;
        }

        const int i = row*ncols + col;
        const float val = expf(x[i]);
        tmp += val;
        dst[i] = val;
    }

    // sum up partial sums
    __syncthreads();
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    for (int block_start = 0; block_start < ncols; block_start += block_size) {
        const int col = block_start + tid;

        if (col >= ncols) {
            break;
        }

        const int i = row*ncols + col;
        dst[i] /= tmp;
    }
}
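The reduction above is the standard warp-level butterfly sum. As a standalone helper it would look roughly like this (a sketch of the idiom, assuming a full 32-thread warp; not code from this commit):

// Each __shfl_xor_sync step folds partial sums across lanes; after five steps
// every lane of the warp holds the sum of all 32 lanes' inputs.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

Because soft_max_f32 is launched with a single warp per row (see soft_max_f32_cuda below, block_dims = WARP_SIZE), this warp-wide sum already covers the whole row.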

static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -1049,6 +1104,19 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
|
|||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
|
||||
}
|
||||
|
||||
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int n_past, cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, nrows_x, 1);
|
||||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, n_past);
|
||||
}
|
||||
|
||||
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(1, nrows_x, 1);
|
||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||
}
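Together, the two launchers implement the "mask, then normalize" step of the attention weights. A usage sketch (hypothetical helper, not part of the patch; assumes d_scores is a device buffer holding nrows x ncols attention scores, and in-place operation is fine because each thread only touches its own element):

// Hypothetical illustration of how the launchers compose.
static void masked_softmax_rows(float * d_scores, int ncols, int nrows, int n_past, cudaStream_t stream) {
    diag_mask_inf_f32_cuda(d_scores, d_scores, ncols, nrows, n_past, stream); // future positions -> large negative
    soft_max_f32_cuda(d_scores, d_scores, ncols, nrows, stream);              // each row then sums to 1
}

Note the launch geometries: diag_mask_inf uses a 2-D grid with ceil(ncols/32) blocks of 32 threads along x and one grid row per matrix row, while soft_max uses a single 32-thread warp per row that strides over the columns.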

// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@ -1479,6 +1547,53 @@ inline void ggml_cuda_op_rope(
|
|||
(void) i1;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_diag_mask_inf(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
||||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||
cudaStream_t & cudaStream_main){
|
||||
|
||||
GGML_ASSERT(src0_ddf_i != nullptr);
|
||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
|
||||
// compute
|
||||
diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_past, cudaStream_main);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
(void) dst;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i02;
|
||||
(void) i1;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_soft_max(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
||||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||
cudaStream_t & cudaStream_main){
|
||||
|
||||
GGML_ASSERT(src0_ddf_i != nullptr);
|
||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
// compute
|
||||
soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i02;
|
||||
(void) i1;
|
||||
}
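Both wrappers above, like ggml_cuda_op_rope before them and ggml_cuda_op_scale after, follow the per-op callback shape that ggml_cuda_op expects. As a sketch, a function-pointer type consistent with these signatures would be (the type name and parameter comments here are descriptive assumptions, not taken from the file):

// ddq = device data, quantized; ddf = device data, float; the *_i suffix marks
// the slice of the tensor currently being processed on one device/stream.
typedef void (*cuda_op_sketch_t)(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main);

Note that ggml_cuda_op_diag_mask_inf reads n_past from src1->data as a plain int32 and ignores src1_ddf_i: for this op, src1 is a small parameter tensor rather than a second float operand.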

inline void ggml_cuda_op_scale(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@ -1970,6 +2085,16 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
|||
(void) dst;
|
||||
}
|
||||
|
||||
void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
|
||||
|
@ -2185,6 +2310,18 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||
}
|
||||
func = ggml_cuda_nop;
|
||||
break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_diag_mask_inf;
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_soft_max;
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
|
|
|
@ -1434,11 +1434,11 @@ static bool llama_eval_internal(
|
|||
// KQ_scaled shape [n_past + N, N, n_head, 1]
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func(KQ_scaled);
|
||||
KQ_scaled->backend = GGML_BACKEND_CPU;
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
offload_func(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
|
|
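For context, the llama.cpp lines above are the step the new CUDA ops accelerate: scale the raw attention scores, mask out future positions, then normalize each row. In standard notation (not taken verbatim from the source):

\text{KQ\_soft\_max} = \mathrm{softmax}\big(\mathrm{mask}_{n_\text{past}}(\text{KQ\_scale}\cdot \text{KQ})\big),
\qquad
\mathrm{mask}_{n_\text{past}}(S)_{r,c} =
\begin{cases} S_{r,c} & c \le n_\text{past} + r \\ -\infty & \text{otherwise} \end{cases}

which is exactly the per-element rule in diag_mask_inf_f32 followed by the row-wise normalization in soft_max_f32.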