simpler rope implementation
This commit is contained in:
parent
cbe2bac281
commit
aa18b93980
1 changed files with 32 additions and 37 deletions
69
ggml-cuda.cu
69
ggml-cuda.cu
|
@ -4356,15 +4356,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
|||
}
|
||||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) {
|
||||
int p = pos[i];
|
||||
p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale;
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
|
@ -4376,7 +4369,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
|
|||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float theta = p0[i2]*powf(theta_scale, col/2);
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
const float p0 = p * freq_scale;
|
||||
const float theta = p0*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
|
@ -4387,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
|
|||
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
|
@ -4399,7 +4394,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
|
|||
const int i = row*ncols + col/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float theta = p0[i2]*powf(theta_scale, col/2);
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
const float p0 = p * freq_scale;
|
||||
const float theta = p0*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
|
@ -4410,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
|
|||
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
|
||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, const int n_ctx) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int half_n_dims = ncols/4;
|
||||
|
||||
|
@ -4425,9 +4422,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||
|
||||
const float col_theta_scale = powf(theta_scale, col);
|
||||
// FIXME: this is likely wrong
|
||||
const float p = p0[i2];
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
|
||||
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
|
||||
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
|
@ -4437,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
|
||||
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
|
||||
|
@ -5374,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
|||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||
}
|
||||
|
||||
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta_rows, theta_scale);
|
||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
}
|
||||
|
||||
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta_rows, theta_scale);
|
||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
||||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
|
||||
}
|
||||
|
||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
||||
|
@ -6105,32 +6102,30 @@ inline void ggml_cuda_op_rope(
|
|||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
if (!src1_extra->copied) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
|
||||
src1_extra->copied = true;
|
||||
int * pos = nullptr;
|
||||
if ((mode & 1) == 0) {
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
pos = (int *) src1_extra->data_device[id];
|
||||
if (!src1_extra->copied) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
|
||||
src1_extra->copied = true;
|
||||
}
|
||||
}
|
||||
|
||||
size_t p0d_as;
|
||||
float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
|
||||
compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
} else if (is_neox) {
|
||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream);
|
||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else {
|
||||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream);
|
||||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_free(p0d, p0d_as);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue