ggml : sync latest (SAM + SD operators, CUDA alibi) (#2709)
* ggml : sync latest (SAM + SD operators, CUDA alibi) ggml-ci * ggml : fix tabs
This commit is contained in:
parent
8e4364f2af
commit
ef3f333d37
6 changed files with 1090 additions and 61 deletions
79
ggml-cuda.cu
79
ggml-cuda.cu
|
@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||
#define CUDA_ROPE_BLOCK_SIZE 256
|
||||
#define CUDA_ALIBI_BLOCK_SIZE 32
|
||||
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
|
@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||
}
|
||||
|
||||
static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
|
||||
const int n_heads_log2_floor, const float m0, const float m1) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i = row*ncols + col;
|
||||
|
||||
const int k = row/k_rows;
|
||||
|
||||
float m_k;
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = powf(m0, k + 1);
|
||||
} else {
|
||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
dst[i] = col * m_k + x[i];
|
||||
}
|
||||
|
||||
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
|
@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
|
|||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
|
||||
}
|
||||
|
||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
||||
const int k_rows, const int n_heads_log2_floor, const float m0,
|
||||
const float m1, cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
|
||||
const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
|
||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
|
||||
}
|
||||
|
||||
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
|
||||
|
@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
|
|||
(void) i1;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_alibi(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
||||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||
cudaStream_t & cudaStream_main){
|
||||
|
||||
GGML_ASSERT(src0_ddf_i != nullptr);
|
||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
GGML_ASSERT(ne01 + n_past == ne00);
|
||||
GGML_ASSERT(n_head == ne02);
|
||||
|
||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
// compute
|
||||
alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
|
||||
|
||||
(void) src1;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i1;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_diag_mask_inf(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
||||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||
|
@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
|
|||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
|
||||
}
|
||||
|
||||
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
(void) src0;
|
||||
(void) src1;
|
||||
|
@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||
}
|
||||
func = ggml_cuda_rope;
|
||||
break;
|
||||
case GGML_OP_ALIBI:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_alibi;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue