add phi3 128k support in cuda

This commit is contained in:
Wei Liu 2024-05-02 02:06:20 +08:00 committed by Georgi Gerganov
parent 8fa413d8b5
commit 56d9fa72de
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 185 additions and 50 deletions

View file

@ -58,10 +58,10 @@ static __global__ void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_pos>
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -88,7 +88,9 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
const float freq_factor = has_freq_facs ? freq_factors[col/2] : 1.0f;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -164,7 +166,7 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@ -175,15 +177,32 @@ static void rope_neox_cuda(
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
else {
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
}
@ -214,17 +233,17 @@ static void rope_cuda_f32(
static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream) {
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream
) {
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -259,11 +278,18 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const float* freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_d;
if (dst->src[2] != nullptr) {
GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2);
freq_factors = (const float*) dst->src[2]->data;
}
}
const bool is_neox = mode & 2;
@ -280,12 +306,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);

73
ggml.c
View file

@ -6275,6 +6275,13 @@ static struct ggml_tensor * ggml_rope_impl(
return result;
}
struct ggml_tensor * ggml_rope_with_freq_factors(
struct ggml_tensor* rope_tensor,
struct ggml_tensor* freq_factors) {
rope_tensor->src[2] = freq_factors;
return rope_tensor;
}
struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -18915,21 +18922,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down),
ggml_rope_with_freq_factors(
ggml_rope_back(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down),
tensor->src[2]),
zero_table);
}
} break;
@ -18954,22 +18963,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_rope_impl(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down,
false),
ggml_rope_with_freq_factors(
ggml_rope_impl(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down,
false),
tensor->src[2]),
zero_table);
}
} break;

View file

@ -304,6 +304,9 @@ enum llm_kv {
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_LONG_FACTORS,
LLM_KV_ROPE_SCALING_SHORT_FACTORS,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
@ -381,6 +384,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_LONG_FACTORS, "%s.rope.scaling.freq_long_factors" },
{ LLM_KV_ROPE_SCALING_SHORT_FACTORS, "%s.rope.scaling.freq_short_factors" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
@ -1754,6 +1760,10 @@ struct llama_hparams {
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
std::vector<float> rope_long_factors;
std::vector<float> rope_short_factors;
float rope_attn_factor = 1.0f;
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
@ -1789,6 +1799,10 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
if (this->rope_long_factors != other.rope_long_factors) return true;
if (this->rope_short_factors != other.rope_short_factors) return true;
if (this->rope_attn_factor != other.rope_attn_factor) return true;
if (this->ssm_d_conv != other.ssm_d_conv) return true;
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
@ -2246,6 +2260,8 @@ struct llama_context {
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2]
// control vectors
struct llama_control_vector cvec;
};
@ -3306,6 +3322,39 @@ struct llama_model_loader {
return get_arr_n(llm_kv(kid), result, required);
}
template<typename T>
bool get_arr(const std::string& key, std::vector<T>& result, const bool required = true) {
const int kid = gguf_find_key(meta, key.c_str());
if (kid < 0) {
if (required) {
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
}
return false;
}
struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
}
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value));
result.resize(arr_info.length);
result.assign((T*)arr_info.data, (T*)arr_info.data + arr_info.length);
return true;
}
template<typename T>
bool get_arr(const enum llm_kv kid, T& result, const bool required = true) {
return get_arr(llm_kv(kid), result, required);
}
template<typename T>
bool get_key(const std::string & key, T & result, const bool required = true) {
auto it = kv_overrides.find(key);
@ -3849,6 +3898,14 @@ static void llm_load_hparams(
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false);
ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false);
GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2);
GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size());
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
// sanity check for n_rot (optional)
{
hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
@ -6821,6 +6878,8 @@ struct llm_build_context {
cb(lctx.inp_K_shift, "K_shift", -1);
ggml_set_input(lctx.inp_K_shift);
lctx.freq_factors = build_freq_factors();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
@ -6832,6 +6891,9 @@ struct llm_build_context {
0),
lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
tmp = ggml_rope_with_freq_factors(tmp, lctx.freq_factors);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(gf, tmp);
}
@ -6934,6 +6996,20 @@ struct llm_build_context {
return lctx.inp_pos;
}
struct ggml_tensor* build_freq_factors() {
if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
lctx.freq_factors = nullptr;
return nullptr;
}
lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2);
cb(lctx.freq_factors, "freq_factors", -1);
ggml_set_input(lctx.freq_factors);
return lctx.freq_factors;
}
struct ggml_tensor * build_inp_out_ids() {
lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
cb(lctx.inp_out_ids, "inp_out_ids", -1);
@ -9052,6 +9128,9 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// rope freq factors for 128k context
struct ggml_tensor* freq_factors = build_freq_factors();
for (int il = 0; il < n_layer; ++il) {
auto residual = inpL;
@ -9092,6 +9171,7 @@ struct llm_build_context {
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Qcur = ggml_rope_with_freq_factors(Qcur, freq_factors);
cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
@ -9101,6 +9181,7 @@ struct llm_build_context {
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_with_freq_factors(Kcur, freq_factors);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
@ -10890,6 +10971,22 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (lctx.freq_factors) {
auto freq_dim = hparams.n_embd_head_k / 2;
GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ? *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1;
if (max_pos + 1 > hparams.n_yarn_orig_ctx) {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
}
else {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
}
}
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = batch.n_tokens;
@ -15417,6 +15514,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
cparams.causal_attn = hparams.causal_attn;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {