implement YaRN for GPT-NeoX RoPE

This commit is contained in:
cebtenzzre 2023-11-01 16:44:49 -04:00
parent 9fc823826e
commit 15f26efdb1
3 changed files with 74 additions and 38 deletions

View file

@ -4439,7 +4439,7 @@ static __device__ void rope_yarn(
// rope == RoPE == rotary positional embedding // rope == RoPE == rotary positional embedding
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope( static __global__ void rope(
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float theta_scale, const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims float ext_factor, float attn_factor, rope_corr_dims corr_dims
) { ) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -4453,7 +4453,7 @@ static __global__ void rope(
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0; const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(theta_scale, col/2); const float theta_base = p*powf(freq_base, -col/ncols);
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -4466,8 +4466,10 @@ static __global__ void rope(
} }
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, static __global__ void rope_neox(
const int p_delta_rows, const float theta_scale) { const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) { if (col >= ncols) {
@ -4478,11 +4480,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
const int i = row*ncols + col/2; const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
// simplified from `(row * ncols + col) * (-1 / ncols)`
const float cur_rot = -col/ncols - row;
const int p = has_pos ? pos[i2] : 0; const int p = has_pos ? pos[i2] : 0;
const float p0 = p*freq_scale; const float theta_base = p*powf(freq_base, cur_rot);
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta); float cos_theta, sin_theta;
const float cos_theta = cosf(theta); rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0]; const float x0 = x[i + 0];
const float x1 = x[i + ncols/2]; const float x1 = x[i + ncols/2];
@ -4491,8 +4496,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
} }
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, static __global__ void rope_glm_f32(
const int p_delta_rows, const float theta_scale, const int n_ctx) { const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
int n_ctx
) {
const int col = blockDim.x*blockIdx.x + threadIdx.x; const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4; const int half_n_dims = ncols/4;
@ -4504,7 +4511,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
const int i = row*ncols + col; const int i = row*ncols + col;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const float col_theta_scale = powf(theta_scale, col); const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
// FIXME: this is likely wrong // FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0; const int p = pos != nullptr ? pos[i2] : 0;
@ -5525,7 +5532,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
template<typename T> template<typename T>
static void rope_cuda( static void rope_cuda(
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float theta_scale, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) { ) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@ -5533,36 +5540,44 @@ static void rope_cuda(
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) { if (pos == nullptr) {
rope<T, false><<<block_nums, block_dims, 0, stream>>>( rope<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, ext_factor, attn_factor, corr_dims x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
); );
} else { } else {
rope<T, true><<<block_nums, block_dims, 0, stream>>>( rope<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, ext_factor, attn_factor, corr_dims x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
); );
} }
} }
template<typename T> template<typename T>
static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, static void rope_neox_cuda(
const int p_delta_rows, const float theta_scale, cudaStream_t stream) { const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) { if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} else { } else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} }
} }
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, static void rope_glm_f32_cuda(
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, int n_ctx, cudaStream_t stream
) {
GGML_ASSERT(ncols % 4 == 0); GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
const dim3 block_nums(num_blocks_x, nrows, 1); const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
} }
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@ -6425,8 +6440,6 @@ inline void ggml_cuda_op_rope(
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const int32_t * pos = nullptr; const int32_t * pos = nullptr;
if ((mode & 1) == 0) { if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
@ -6437,31 +6450,37 @@ inline void ggml_cuda_op_rope(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute // compute
if (is_glm) { if (is_glm) {
GGML_ASSERT(false); GGML_ASSERT(false);
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
} else if (is_neox) { } else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_neox_cuda(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); rope_neox_cuda(
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
} else { } else {
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_cuda( rope_cuda(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, ext_factor, (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream attn_factor, corr_dims, main_stream
); );
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_cuda( rope_cuda(
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, ext_factor, (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream attn_factor, corr_dims, main_stream
); );
} else { } else {

View file

@ -1125,9 +1125,12 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float cos_theta = cos(theta); const float cur_rot = inv_ndims*ic - ib;
const float sin_theta = sin(theta);
const float theta = theta_0 * pow(freq_base, cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2; const int64_t i0 = ib*n_dims + ic/2;

22
ggml.c
View file

@ -13486,6 +13486,7 @@ static void ggml_compute_forward_rope_f32(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@ -13556,8 +13557,14 @@ static void ggml_compute_forward_rope_f32(
theta_base *= freq_scale; theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta_base); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float sin_theta = sinf(theta_base); float cur_rot = inv_ndims * ic - ib;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
theta_base *= theta_scale; theta_base *= theta_scale;
@ -13628,6 +13635,7 @@ static void ggml_compute_forward_rope_f16(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@ -13694,8 +13702,14 @@ static void ggml_compute_forward_rope_f16(
theta_base *= freq_scale; theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta_base); // simplified from `(ib * n_dims + ic) * inv_ndims`
const float sin_theta = sinf(theta_base); float cur_rot = inv_ndims * ic - ib;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
theta_base *= theta_scale; theta_base *= theta_scale;