Modified RoPE with linear scaling

When the context size is greater than the maximum context size seen during
training, scale the position given to RoPE by training context / n_ctx.
Iwan Kawrakow 2023-06-27 15:00:22 +03:00
parent 0be54f75a6
commit cda30038e4
6 changed files with 34 additions and 3 deletions
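The rule described in the commit message amounts to a single branch on the current context size. Below is a minimal C sketch of that rule (not part of the diff; the helper name is illustrative), using the GGML_TRAINING_CTX constant this commit introduces:

    #ifndef GGML_TRAINING_CTX
    #define GGML_TRAINING_CTX 2176   // default introduced by this commit (2048 "stepped out" by 128)
    #endif

    // Illustrative helper (not in the diff): scale a RoPE position only when the
    // requested context exceeds the training context, mirroring the p and theta
    // updates in the hunks below.
    static inline float rope_scaled_position(float p0, int n_ctx) {
        return n_ctx <= GGML_TRAINING_CTX
            ? p0
            : p0 * (float) GGML_TRAINING_CTX / (float) n_ctx;
    }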

CMakeLists.txt

@@ -72,6 +72,7 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern
 set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
 option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+set(LLAMA_TRAINIG_CTX "2176" CACHE STRING "llama: model training maximum context")
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_METAL "llama: use Metal" OFF)
 option(LLAMA_K_QUANTS "llama: use k-quants" ON)

@@ -125,6 +126,8 @@ set(CMAKE_C_STANDARD_REQUIRED true)
 set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads REQUIRED)
+add_compile_definitions(GGML_TRAINING_CTX=${LLAMA_TRAINIG_CTX})
 if (NOT MSVC)
     if (LLAMA_SANITIZE_THREAD)
         add_compile_options(-fsanitize=thread)

Makefile

@@ -130,6 +130,11 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
     endif
 endif
+ifdef LLAMA_TRAINIG_CTX
+    CFLAGS += -DGGML_TRAINING_CTX=$(LLAMA_TRAINIG_CTX)
+    CXXFLAGS += -DGGML_TRAINING_CTX=$(LLAMA_TRAINIG_CTX)
+endif
 ifndef LLAMA_NO_K_QUANTS
     CFLAGS += -DGGML_USE_K_QUANTS
     CXXFLAGS += -DGGML_USE_K_QUANTS

ggml-cuda.cu

@@ -2175,10 +2175,13 @@ inline void ggml_cuda_op_rope(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode = ((int32_t *) src1->data)[2];
+    const int n_ctx = ((int32_t *) src1->data)[3];
     GGML_ASSERT(mode == 0);
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
-    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+    const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
+    const float p = n_ctx <= GGML_TRAINING_CTX ? p0 : p0 * GGML_TRAINING_CTX / n_ctx;
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
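For reference, the rope operator packs its scalar parameters into src1 as consecutive 32-bit integers, which the kernel above reads back by index. A small C sketch of that layout (field names are illustrative, not from the diff):

    #include <stdint.h>

    // Illustrative view of the values read above as ((int32_t *) src1->data)[0..3].
    struct rope_params_sketch {
        int32_t n_past; // tokens already processed (start position)
        int32_t n_dims; // number of rotated dimensions (n_rot at the call site)
        int32_t mode;   // rope mode; the CUDA path asserts mode == 0
        int32_t n_ctx;  // current context size, added by this commit
    };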

ggml.c

@@ -12535,6 +12535,9 @@ static void ggml_compute_forward_rope_f32(
                 dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
             }
         } else if (!is_neox) {
+            if (n_ctx > GGML_TRAINING_CTX) {
+                theta = theta * GGML_TRAINING_CTX / n_ctx;
+            }
             for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                 const float cos_theta = cosf(theta);
                 const float sin_theta = sinf(theta);

@@ -12675,6 +12678,9 @@ static void ggml_compute_forward_rope_f16(
                 dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
             }
         } if (!is_neox) {
+            if (n_ctx > GGML_TRAINING_CTX) {
+                theta = theta * GGML_TRAINING_CTX / n_ctx;
+            }
             for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                 const float cos_theta = cosf(theta);
                 const float sin_theta = sinf(theta);

@@ -12760,6 +12766,7 @@ static void ggml_compute_forward_rope_back_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode = ((int32_t *) src1->data)[2];
+    const int n_ctx = ((int64_t *) src1->data)[3];
     assert(n_past >= 0);

@@ -12813,6 +12820,9 @@ static void ggml_compute_forward_rope_back_f32(
             float theta = (float)p;
             if (!is_neox) {
+                if (n_ctx > GGML_TRAINING_CTX) {
+                    theta = theta * GGML_TRAINING_CTX / n_ctx;
+                }
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);

@@ -12873,6 +12883,7 @@ static void ggml_compute_forward_rope_back_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode = ((int32_t *) src1->data)[2];
+    const int n_ctx = ((int64_t *) src1->data)[3];
     assert(n_past >= 0);

@@ -12926,6 +12937,9 @@ static void ggml_compute_forward_rope_back_f16(
             float theta = (float)p;
             if (!is_neox) {
+                if (n_ctx > GGML_TRAINING_CTX) {
+                    theta = theta * GGML_TRAINING_CTX / n_ctx;
+                }
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
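The three lines added to each non-NeoX branch above rescale theta, which is initialized from the token position p, once before the per-pair rotation loop. A compact C sketch of that loop (simplified, with an illustrative function name; only the if-block is the behavior added by this commit):

    #include <math.h>

    #ifndef GGML_TRAINING_CTX
    #define GGML_TRAINING_CTX 2176
    #endif

    // Simplified sketch of the non-NeoX RoPE rotation for one row of n_dims floats.
    // p is the token position; theta_scale = powf(10000.0f, -2.0f/n_dims) as in the
    // kernels above.
    static void rope_row_sketch(float * dst, const float * src,
                                int n_dims, float p, float theta_scale, int n_ctx) {
        float theta = p;
        if (n_ctx > GGML_TRAINING_CTX) {
            theta = theta * GGML_TRAINING_CTX / n_ctx;  // linear position scaling (this commit)
        }
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);
            const float x0 = src[i0 + 0];
            const float x1 = src[i0 + 1];
            dst[i0 + 0] = x0*cos_theta - x1*sin_theta;  // rotate each (x0, x1) pair
            dst[i0 + 1] = x0*sin_theta + x1*cos_theta;
            theta *= theta_scale;                       // next pair uses a smaller angle
        }
    }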

ggml.h

@@ -201,6 +201,12 @@
 #define GGML_MAX_NAME 48
 #define GGML_DEFAULT_N_THREADS 4
+// Maximum training context of the model in use
+// For the LLaMA models this is normally 2048, but somehow "stepping out" by 128 gives better results (tested at 7B and 13B)
+#ifndef GGML_TRAINING_CTX
+#define GGML_TRAINING_CTX 2176
+#endif
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
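As a quick worked example of the default above (illustrative, not from the commit): with GGML_TRAINING_CTX = 2176, a hypothetical 4096-token context scales every RoPE position by 2176/4096 = 0.53125, while any context of 2176 tokens or fewer is left untouched.

    #include <assert.h>

    int main(void) {
        const float training_ctx = 2176.0f;   // default from the ggml.h hunk above
        const float n_ctx        = 4096.0f;   // hypothetical long context
        const float scale        = training_ctx / n_ctx;
        assert(scale == 0.53125f);            // 2176/4096 = 17/32, exact in binary
        return 0;
    }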

llama.cpp

@@ -1491,11 +1491,11 @@ static bool llama_eval_internal(
         offload_func_kq(tmpq);
         ggml_set_name(tmpq, "tmpq");
-        struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+        struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, n_ctx);
         offload_func_kq(Kcur);
         ggml_set_name(Kcur, "Kcur");
-        struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+        struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, n_ctx);
         offload_func_kq(Qcur);
         ggml_set_name(Qcur, "Qcur");