Compare commits


17 commits

Author | SHA1 | Message | Date
Georgi Gerganov | d2f1e0dacc | Merge branch 'cuda-cublas-opts' into gg/phi-2 | 2023-12-17 08:41:46 +02:00
Georgi Gerganov | b672c169ca | ggml : fix NeoX rope to rotate just first n_dims | 2023-12-17 08:39:18 +02:00
Georgi Gerganov | e75889a9b8 | Merge branch 'master' into cuda-cublas-opts (ggml-ci) | 2023-12-17 08:20:02 +02:00
Georgi Gerganov | 0644c3be51 | phi-2 : scale Q instead of KQ for better precision | 2023-12-16 18:01:08 +02:00
Georgi Gerganov | 0b6ffa580c | convert : revert "added_tokens_decoder" change | 2023-12-16 16:05:35 +02:00
Georgi Gerganov | a878be4cb1 | convert : phi don't add BOS token | 2023-12-16 11:20:11 +02:00
Georgi Gerganov | 5469d82d5a | llama : fix meta KV override bug | 2023-12-16 11:19:56 +02:00
Georgi Gerganov | 7500fa2f07 | py : whitespaces | 2023-12-16 11:01:02 +02:00
Georgi Gerganov | aa5c881adb | phi-2 : use layer norm eps | 2023-12-16 10:54:10 +02:00
Georgi Gerganov | a2a3d2c8d7 | phi-2 : various fixes | 2023-12-16 10:46:18 +02:00
Ebey Abraham | e20765534d | fix breaking change | 2023-12-16 00:41:06 +00:00
Ebey Abraham | 12cc80cb89 | phi2 implementation | 2023-12-15 20:56:57 +00:00
Georgi Gerganov | 66a8dd35a0 | Merge branch 'master' into cuda-cublas-opts | 2023-12-05 20:54:33 +02:00
Georgi Gerganov | c830a0537b | Merge branch 'master' into cuda-cublas-opts (ggml-ci) | 2023-11-27 11:54:02 +02:00
Georgi Gerganov | e374227221 | Revert "cuda : use CUBLAS_COMPUTE_16F for non-attention ops" (this reverts commit 0f2498f25d) | 2023-10-28 12:20:08 +03:00
Georgi Gerganov | 0f2498f25d | cuda : use CUBLAS_COMPUTE_16F for non-attention ops | 2023-10-27 20:19:42 +03:00
Georgi Gerganov | 3b9ea655d4 | cuda : use CUBLAS_COMPUTE_32F to speed-up and avoid dst cpy | 2023-10-27 18:13:54 +03:00
8 changed files with 388 additions and 80 deletions

convert-hf-to-gguf.py

@@ -182,6 +182,8 @@ class Model:
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "PhiForCausalLM":
+            return Phi2Model
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ class Model:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
+        if arch == "PhiForCausalLM":
+            return gguf.MODEL_ARCH.PHI2

         raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -980,6 +984,24 @@ class QwenModel(Model):
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)

+
+class Phi2Model(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("Phi2")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_add_bos_token(False)
+
+
 ###### CONVERSION LOGIC ######
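As context for the Phi2Model hunk above, here is a minimal sketch of which HF config.json keys the converter reads and what they become in the GGUF header. The numeric values are the published phi-2 hyperparameters and are illustrative, not taken from this PR; the arch-prefixed GGUF key names are assumptions based on gguf-py's usual `{arch}.<key>` convention.

```python
# Hypothetical sketch, not code from the PR: the hparams read by
# Phi2Model.set_gguf_parameters and the GGUF keys they map to.
hparams = {
    "n_layer": 32,               # -> phi2.block_count
    "n_positions": 2048,         # -> phi2.context_length
    "n_embd": 2560,              # -> phi2.embedding_length
    "n_head": 32,                # -> phi2.attention.head_count (also reused for head_count_kv)
    "layer_norm_epsilon": 1e-5,  # -> phi2.attention.layer_norm_epsilon
    "rotary_dim": 32,            # -> phi2.rope.dimension_count
}

# The FFN width is not stored in the config; the converter derives it as 4 * n_embd:
assert 4 * hparams["n_embd"] == 10240   # -> phi2.feed_forward_length
```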

ggml-cuda.cu

@@ -4998,22 +4998,29 @@ static __global__ void rope_neox(
     const int ib = col / n_dims;
     const int ic = col % n_dims;
 
-    const int i = row*ncols + ib*n_dims + ic/2;
-    const int i2 = row/p_delta_rows;
+    if (ib == 0) {
+        const int i = row*ncols + ib*n_dims + ic/2;
+        const int i2 = row/p_delta_rows;
 
-    float cur_rot = inv_ndims * ic - ib;
+        float cur_rot = inv_ndims * ic - ib;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+        const int p = has_pos ? pos[i2] : 0;
+        const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+        float cos_theta, sin_theta;
+        rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+        const float x0 = x[i + 0];
+        const float x1 = x[i + n_dims/2];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+        dst[i + 0] = x0*cos_theta - x1*sin_theta;
+        dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    } else {
+        const int i = row*ncols + ib*n_dims + ic;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+    }
 }
 
 static __global__ void rope_glm_f32(
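The hunk above is the core of commit b672c169ca: previously `rope_neox` effectively assumed every column gets rotated, whereas now threads with `ib > 0` (columns past `n_dims`) copy their values through untouched. phi-2 relies on this because it ropes only the first `rotary_dim = 32` features of each 80-dim head. A minimal NumPy sketch of the fixed semantics, assuming the standard NeoX theta schedule:

```python
import numpy as np

def neox_rope_partial(x, pos, n_dims, base=10000.0):
    # Rotate only the first n_dims features of one head vector `x`
    # (shape (head_dim,)) at position `pos`; pass the rest through.
    out = x.copy()                     # features >= n_dims are copied as-is
    half = n_dims // 2
    inv_freq = base ** (-np.arange(0, n_dims, 2) / n_dims)   # (half,)
    theta = pos * inv_freq
    x0, x1 = x[:half], x[half:n_dims]  # NeoX pairs (i, i + n_dims/2)
    out[:half]       = x0 * np.cos(theta) - x1 * np.sin(theta)
    out[half:n_dims] = x0 * np.sin(theta) + x1 * np.cos(theta)
    return out

# phi-2: head_dim = 80, but only the first 32 features are rotated
x = np.random.randn(80).astype(np.float32)
y = neox_rope_partial(x, pos=7, n_dims=32)
assert np.allclose(y[32:], x[32:])     # the tail is untouched
```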
@@ -7399,27 +7406,20 @@ inline void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
 
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
 
         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                                src1_ptr, CUDA_R_16F, ne10,
-                    &beta_f16,  dst_f16,  CUDA_R_16F, ldc,
-                    CUBLAS_COMPUTE_16F,
+                    &alpha, src0_ptr, CUDA_R_16F, ne00,
+                            src1_ptr, CUDA_R_16F, ne10,
+                    &beta,  dst_dd_i, CUDA_R_32F, ldc,
+                    CUBLAS_COMPUTE_32F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
         if (src0_as != 0) {
             ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
@@ -8299,8 +8299,8 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+__global__ static void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, float * dst_f32,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
         int ne23,
@@ -8320,7 +8320,7 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f32 + i12* nb2   + i13* nb3;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8375,9 +8375,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -8385,8 +8382,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
+    const float alpha = 1.0f;
+    const float beta  = 0.0f;
 
 #if 0
     // use cublasGemmEx
@@ -8399,10 +8396,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
                 CUBLAS_CHECK(
                         cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
-                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
-                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                            CUBLAS_COMPUTE_16F,
+                            &alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                    (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta,  (      char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  , CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
@@ -8414,11 +8411,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                &alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                        (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta,  (      char *)     dst_ddf, CUDA_R_32F, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -8435,7 +8432,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
+                src0_as_f16, src1_as_f16, dst_ddf,
                 ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
@@ -8448,11 +8445,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+                &alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_32F, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
     if (ptrs_src_s != 0) {
@@ -8464,11 +8461,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst)
     }
 #endif
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
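The cuBLAS hunks above all implement the same idea from commit 3b9ea655d4: run the fp16 GEMMs with CUBLAS_COMPUTE_32F and write fp32 output directly into `dst_dd_i`/`dst_ddf`, which removes the temporary `dst_f16` pool buffer plus the trailing fp16-to-fp32 copy and avoids accumulating dot products in half precision. A tiny NumPy sketch of the accumulation difference (illustrative only, not the CUDA code):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float16)

# fp16 accumulator: every partial sum is rounded back to half precision,
# roughly analogous to what CUBLAS_COMPUTE_16F allows per dot product
s16 = np.float16(0.0)
for v in x:
    s16 = np.float16(s16 + v)

# fp32 accumulator, analogous to CUBLAS_COMPUTE_32F
s32 = x.astype(np.float32).sum()

print(s16, s32)   # the fp16 running sum drifts noticeably at this length
```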

ggml-metal.metal

@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
                     }
                 } else {
-                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                        for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+                    for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+                        if (ic < n_dims) {
+                            const int64_t ib = 0;
 
                             // simplified from `(ib * n_dims + ic) * inv_ndims`
                             const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1723,14 @@ kernel void kernel_rope(
                             dst_data[0]        = x0*cos_theta - x1*sin_theta;
                             dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        } else {
+                            const int64_t i0 = ic;
+
+                            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            dst_data[0] = src[0];
+                            dst_data[1] = src[1];
                         }
                     }
                 }

ggml.c (38 changes)

@@ -9168,6 +9168,8 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9237,6 +9239,8 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11562,10 +11566,13 @@ static void ggml_compute_forward_rope_f32(
                 }
             } else {
                 // TODO: this might be wrong for ne0 != n_dims - need double check
-                // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                // it seems we have to rope just the first n_dims elements and do nothing with the rest
+                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                 theta_base *= freq_scale;
-                for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                    for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                for (int64_t ic = 0; ic < ne0; ic += 2) {
+                    if (ic < n_dims) {
+                        const int64_t ib = 0;
 
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;
@@ -11588,6 +11595,14 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[0]        = x0*cos_theta - x1*sin_theta;
                         dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const int64_t i0 = ic;
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
                     }
                 }
             }
@@ -11715,10 +11730,13 @@ static void ggml_compute_forward_rope_f16(
                 }
             } else {
                 // TODO: this might be wrong for ne0 != n_dims - need double check
-                // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                // it seems we have to rope just the first n_dims elements and do nothing with the rest
+                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                 theta_base *= freq_scale;
-                for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                    for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                for (int64_t ic = 0; ic < ne0; ic += 2) {
+                    if (ic < n_dims) {
+                        const int64_t ib = 0;
 
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;
@@ -11741,6 +11759,14 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const int64_t i0 = ic;
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
                     }
                 }
             }

gguf-py/gguf/constants.py

@@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
     BLOOM     = auto()
     STABLELM  = auto()
     QWEN      = auto()
+    PHI2      = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BLOOM:     "bloom",
     MODEL_ARCH.STABLELM:  "stablelm",
     MODEL_ARCH.QWEN:      "qwen",
+    MODEL_ARCH.PHI2:      "phi2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -350,6 +352,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.GPT2: [
         # TODO
     ],
+    MODEL_ARCH.PHI2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ]
     # TODO
 }

gguf-py/gguf/tensor_mapping.py

@@ -17,6 +17,7 @@ class TensorNameMap:
             "tok_embeddings",                           # llama-pth
             "embeddings.word_embeddings",               # bert
             "language_model.embedding.word_embeddings", # persimmon
+            "transformer.embd.wte",                     # phi2
         ),
 
         # Token type embeddings
@@ -41,6 +42,7 @@ class TensorNameMap:
             "lm_head",                  # gpt2 mpt falcon llama-hf baichuan qwen
             "output",                   # llama-pth bloom
             "word_embeddings_for_head", # persimmon
+            "lm_head.linear",           # phi2
         ),
 
         # Output norm
@@ -53,6 +55,7 @@ class TensorNameMap:
             "transformer.norm_f",                     # mpt
             "ln_f",                                   # refact bloom qwen
             "language_model.encoder.final_layernorm", # persimmon
+            "lm_head.ln",                             # phi2
         ),
 
         # Rope frequencies
@@ -75,6 +78,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.LayerNorm",       # bert
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
+            "transformer.h.{bid}.ln",                               # phi2
         ),
 
         # Attention norm 2
@@ -90,6 +94,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.query_key_value",                 # falcon
             "h.{bid}.self_attention.query_key_value",                             # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
+            "transformer.h.{bid}.mixer.Wqkv",                                     # phi2
         ),
 
         # Attention query
@@ -128,6 +133,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.attention.output.dense",                # bert
            "transformer.h.{bid}.attn.out_proj",                         # gpt-j
            "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
+           "transformer.h.{bid}.mixer.out_proj",                        # phi2
        ),
 
        # Rotary embeddings
@@ -167,6 +173,7 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.fc_in",                         # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
            "transformer.h.{bid}.mlp.w1",                            # qwen
+           "transformer.h.{bid}.mlp.fc1",                           # phi2
        ),
 
        MODEL_TENSOR.FFN_UP_EXP: (
@@ -198,6 +205,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.output.dense",                      # bert
            "transformer.h.{bid}.mlp.fc_out",                        # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+           "transformer.h.{bid}.mlp.fc2",                           # phi2
        ),
 
        MODEL_TENSOR.FFN_DOWN_EXP: (
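Taken together, the phi2 entries above translate HF checkpoint tensor names into GGUF names at conversion time. An illustrative sketch of the resulting mapping (the real lookup lives in gguf-py's TensorNameMap, and the `.weight`/`.bias` suffixes are handled separately; this simplified dict is not the library API):

```python
import re

PHI2_TO_GGUF = {
    "transformer.embd.wte":               "token_embd",
    "lm_head.linear":                     "output",
    "lm_head.ln":                         "output_norm",
    "transformer.h.{bid}.ln":             "blk.{bid}.attn_norm",
    "transformer.h.{bid}.mixer.Wqkv":     "blk.{bid}.attn_qkv",
    "transformer.h.{bid}.mixer.out_proj": "blk.{bid}.attn_output",
    "transformer.h.{bid}.mlp.fc1":        "blk.{bid}.ffn_up",
    "transformer.h.{bid}.mlp.fc2":        "blk.{bid}.ffn_down",
}

def map_name(hf_name: str) -> str:
    # turn each template into a regex with {bid} as a captured block index
    for pat, gguf_name in PHI2_TO_GGUF.items():
        m = re.fullmatch(re.escape(pat).replace(r"\{bid\}", r"(\d+)"), hf_name)
        if m:
            return gguf_name.format(bid=m.group(1)) if m.groups() else gguf_name
    raise KeyError(hf_name)

assert map_name("transformer.h.0.mixer.Wqkv") == "blk.0.attn_qkv"
assert map_name("lm_head.linear") == "output"
```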

llama.cpp (282 changes)

@@ -195,6 +195,7 @@ enum llm_arch {
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
+    LLM_ARCH_PHI2,
 
     LLM_ARCH_UNKNOWN,
 };
@@ -212,6 +213,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_BLOOM,    "bloom"    },
     { LLM_ARCH_STABLELM, "stablelm" },
     { LLM_ARCH_QWEN,     "qwen"     },
+    { LLM_ARCH_PHI2,     "phi2"     },
 };
 
 enum llm_kv {
@@ -550,6 +552,19 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PHI2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 
     {
         LLM_ARCH_UNKNOWN,
@@ -1420,6 +1435,7 @@ struct llama_model {
     struct ggml_tensor * output_norm;
     struct ggml_tensor * output_norm_b;
     struct ggml_tensor * output;
+    struct ggml_tensor * output_b;
 
     std::vector<llama_layer> layers;
@@ -1937,7 +1953,7 @@ namespace GGUFMeta {
                 target = override->bool_value;
                 return true;
             }
-            return true;
+            return false;
         }
 
         template<typename OT>
@@ -2634,6 +2650,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PHI2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
 
         default: (void)0;
     }
@@ -2986,7 +3011,7 @@ static void llm_load_tensors(
     (void) main_gpu;
 
-    enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+    enum ggml_backend_type llama_backend_offload       = GGML_BACKEND_CPU;
     enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
 
 #ifdef GGML_USE_CUBLAS
@@ -3629,7 +3654,73 @@ static void llm_load_tensors(
                         }
                     }
                 } break;
+            case LLM_ARCH_PHI2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend_type backend_norm;
+                        ggml_backend_type backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            backend_norm   = llama_backend_offload;
+                            backend_output = llama_backend_offload;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+                        model.output_b      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab},         backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                            vram_weights += ggml_nbytes(model.output);
+                            vram_weights += ggml_nbytes(model.output_b);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+                        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, backend);
+
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa},         backend);
+
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+                        layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd},         backend);
+
+                        layer.ffn_down   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
+                        layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd},       backend);
+
+                        layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+                        layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff},         backend);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                                ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.bqkv)        +
+                                ggml_nbytes(layer.wo)        + ggml_nbytes(layer.bo)          +
+                                ggml_nbytes(layer.ffn_up)    + ggml_nbytes(layer.ffn_up_b)    +
+                                ggml_nbytes(layer.ffn_down)  + ggml_nbytes(layer.ffn_down_b);
+                        }
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -4001,6 +4092,7 @@ static struct ggml_tensor * llm_build_kqv(
          int32_t   n_tokens,
          int32_t   n_kv,
          float     max_alibi_bias,
+         float     scale,
          const llm_build_cb & cb,
          int       il) {
     const int64_t n_embd = hparams.n_embd;
@@ -4042,7 +4134,7 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
     } else {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
        cb(kq, "kq_soft_max_ext", il);
     }
@@ -4251,7 +4343,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -4434,7 +4526,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -4558,7 +4650,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -4658,7 +4750,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -4867,7 +4959,7 @@ struct llm_build_context {
                 // TODO: not tested, could be broken
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -4958,7 +5050,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -5055,7 +5147,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -5149,7 +5241,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -5262,7 +5354,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -5319,15 +5411,15 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
         // KQ_scale
-        struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
         cb(KQ_scale, "KQ_scale", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
 
         // shift the entire K-cache if needed
@@ -5379,7 +5471,7 @@ struct llm_build_context {
                 cur = llm_build_kqv(ctx0, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
@@ -5421,6 +5513,122 @@ struct llm_build_context {
         ggml_build_forward_expand(gf, cur);
 
+        return gf;
+    }
+
+    struct ggml_cgraph * build_phi2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * attn_norm_output;
+        struct ggml_tensor * ffn_output;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // Q_scale
+        struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(Q_scale, "Q_scale", -1);
+
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(attn_norm_output, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, Q_scale);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // FF
+            {
+                ffn_output = llm_build_ffn(ctx0, attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(ffn_output, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_output);
+            cb(cur, "l_out", il);
+
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output_no_bias", -1);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
         return gf;
     }
 };
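A structural note on `build_phi2` above: unlike the llama-style blocks earlier in this file, the attention and the FFN both consume the same `attn_norm_output`, and both results are added to the residual stream (`cur = kqv_out + ffn_out + inpL`). A minimal sketch of the two block layouts, with `attn`, `ffn`, and the norms as stand-in callables rather than llama.cpp APIs:

```python
def phi2_block(x, attn, ffn, ln):
    # one shared layer norm feeds BOTH branches; outputs are added in parallel
    h = ln(x)                    # attn_norm_output
    return attn(h) + ffn(h) + x  # cur = kqv_out + ffn_out + inpL

def llama_style_block(x, attn, ffn, ln1, ln2):
    # for contrast: sequential residuals, each branch with its own norm
    x = x + attn(ln1(x))
    return x + ffn(ln2(x))
```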
@@ -5436,7 +5644,7 @@ enum llm_offload_func_e {
     OFFLOAD_FUNC_FRC, // force offload
     OFFLOAD_FUNC_KQV,
     OFFLOAD_FUNC_NR,
-    OFFLOAD_FUNC_EMB,
+    OFFLOAD_FUNC_EMB, // embeddings
     OFFLOAD_FUNC_OUT,
 };
@@ -5521,6 +5729,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "pos_embd",  OFFLOAD_FUNC_NR },
 
     { "inp_pos",   OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
+    { "Q_scale",   OFFLOAD_FUNC_FRC },
     { "KQ_scale",  OFFLOAD_FUNC_FRC },
     { "KQ_mask",   OFFLOAD_FUNC_FRC },
     { "K_shift",   OFFLOAD_FUNC_FRC },
@@ -5605,6 +5814,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "l_out", OFFLOAD_FUNC },
 
     { "result_norm",           OFFLOAD_FUNC_EMB },
+    { "result_output_no_bias", OFFLOAD_FUNC_EMB },
     { "result_output",         OFFLOAD_FUNC_OUT },
 };
@@ -5622,6 +5832,7 @@ static struct ggml_cgraph * llama_build_graph(
     bool alloc_inp_tokens   = false;
     bool alloc_inp_embd     = false;
     bool alloc_inp_pos      = false;
+    bool alloc_inp_Q_scale  = false;
     bool alloc_inp_KQ_scale = false;
     bool alloc_inp_KQ_mask  = false;
     bool alloc_inp_K_shift  = false;
@@ -5689,7 +5900,7 @@ static struct ggml_cgraph * llama_build_graph(
         alloc_inp_pos = true;
     }
 
-    if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+    if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
         ggml_allocr_alloc(lctx.alloc, cur);
 
         if (!ggml_allocr_is_measure(lctx.alloc)) {
@@ -5697,6 +5908,23 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
         }
 
+        alloc_inp_Q_scale = true;
+    }
+
+    if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
+        ggml_allocr_alloc(lctx.alloc, cur);
+
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            const int64_t n_embd_head = model.hparams.n_embd_head();
+            if (model.arch == LLM_ARCH_PHI2) {
+                // with phi2, we scale the Q to avoid precision issues
+                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
+                ggml_set_f32(cur, 1.0f);
+            } else {
+                ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
+            }
+        }
+
         alloc_inp_KQ_scale = true;
     }
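The special case above implements commit 0644c3be51: for phi-2 the graph applies `1/sqrt(n_embd_head)` to Q before the K·Qᵀ product (the new `Q_scale` input) and sets the soft-max scale `KQ_scale` to 1.0. In exact arithmetic `softmax((sQ)Kᵀ) = softmax(s(QKᵀ))`, but pre-scaling keeps the fp16 intermediate products away from the half-precision maximum (about 65504). A deterministic NumPy sketch with illustrative magnitudes:

```python
import numpy as np

d = 80                                   # phi-2 head size
q = np.full(d, 32.0, dtype=np.float16)   # deliberately large activations
k = np.full(d, 32.0, dtype=np.float16)

scale = np.float16(1.0 / np.sqrt(d))

unscaled  = np.float16(q @ k)            # 32*32*80 = 81920 -> overflows fp16 to inf
prescaled = np.float16((q * scale) @ k)  # ~9160, comfortably in fp16 range

print(unscaled, prescaled)
```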
@@ -5921,6 +6149,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen();
             } break;
+        case LLM_ARCH_PHI2:
+            {
+                result = llm.build_phi2();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -6054,12 +6286,16 @@ static int llama_decode_internal(
     ggml_allocr_alloc_graph(lctx.alloc, gf);
 
-    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+    // the output is always the last tensor in the graph
+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
 
-    GGML_ASSERT(strcmp(res->name,        "result_output") == 0);
-    GGML_ASSERT(strcmp(embeddings->name, "result_norm")   == 0);
+    // the embeddings could be the second to last tensor, or the third to last tensor
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+    if (strcmp(embeddings->name, "result_norm") != 0) {
+        embeddings = gf->nodes[gf->n_nodes - 3];
+        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    }
 
 #ifdef GGML_USE_CUBLAS
     for (int i = 0; i < gf->n_leafs; i++) {

tests/test-backend-ops.cpp

@@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
             test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
             test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1}, 20, 2, 512)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1}, 32, 2, 512)); // neox (phi-2)
         }
 
     test_cases.emplace_back(new test_alibi());