From 3b9ea655d4cef780b2330071ec7838623351f3ae Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Oct 2023 18:13:54 +0300 Subject: [PATCH 01/13] cuda : use CUBLAS_COMPUTE_32F to speed-up and avoid dst cpy --- ggml-cuda.cu | 57 +++++++++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1ba951f68..e03e500d7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6385,27 +6385,19 @@ inline void ggml_cuda_op_mul_mat_cublas( } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; + const float alpha = 1.0f; + const float beta = 0.0f; CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16, CUDA_R_16F, ldc, - CUBLAS_COMPUTE_16F, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - - ggml_cuda_pool_free(dst_f16, dst_as); - if (src0_as != 0) { ggml_cuda_pool_free(src0_as_f16, src0_as); } @@ -7189,9 +7181,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); - GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -7199,8 +7188,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; + const float alpha = 1.0f; + const float beta = 0.0f; #if 0 // use cublasGemmEx @@ -7213,10 +7202,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), - &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01, - CUBLAS_COMPUTE_16F, + &alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), + (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), + &beta, ( char *) dst_ddf + i12* dst->nb[2] + i13* dst->nb[3] , CUDA_R_32F, ne01, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } } @@ -7228,11 +7217,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA - (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // 
strideB - &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC + &alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA + (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB + &beta, ( char *) dst_ddf, CUDA_R_32F, ne01, dst->nb[2]/sizeof(float), // strideC ne12*ne13, - CUBLAS_COMPUTE_16F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { // use cublasGemmBatchedEx @@ -7249,7 +7238,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3]; ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2; - ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2; + ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_ddf + i12* dst->nb[2] + i13* dst->nb[3] ; } } @@ -7269,11 +7258,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), - (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), - &beta_f16, ( void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01, + &alpha, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), + (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), + &beta, ( void **) (ptrs_as + 2*ne23), CUDA_R_32F, ne01, ne23, - CUBLAS_COMPUTE_16F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // free device memory for pointers @@ -7282,11 +7271,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const } #endif - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - ggml_cuda_pool_free(src1_as_f16, src1_as); - ggml_cuda_pool_free(dst_f16, dst_as); } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From 0f2498f25d7e278f075d060e8e77e68dacf4e90c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Oct 2023 20:15:21 +0300 Subject: [PATCH 02/13] cuda : use CUBLAS_COMPUTE_16F for non-attention ops --- ggml-cuda.cu | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e03e500d7..d16b8f9c5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6385,8 +6385,11 @@ inline void ggml_cuda_op_mul_mat_cublas( } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const half *) src1_ddq_i : src1_as_f16; - const float alpha = 1.0f; - const float beta = 0.0f; + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); + + const half alpha = 1.0f; + const half beta = 0.0f; CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( @@ -6394,10 +6397,15 @@ inline void ggml_cuda_op_mul_mat_cublas( row_diff, src1_ncols, ne10, &alpha, src0_ptr, CUDA_R_16F, ne00, src1_ptr, CUDA_R_16F, ne10, - &beta, dst_dd_i, CUDA_R_32F, ldc, - CUBLAS_COMPUTE_32F, + &beta, dst_f16, CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); + + ggml_cuda_pool_free(dst_f16, dst_as); + if (src0_as != 0) { ggml_cuda_pool_free(src0_as_f16, src0_as); } From e374227221a5c72ce6fd12bf6bc9db8c72101546 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 28 Oct 2023 12:20:08 +0300 Subject: [PATCH 03/13] Revert "cuda : use CUBLAS_COMPUTE_16F for non-attention ops" This reverts commit 0f2498f25d7e278f075d060e8e77e68dacf4e90c. --- ggml-cuda.cu | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d16b8f9c5..e03e500d7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6385,11 +6385,8 @@ inline void ggml_cuda_op_mul_mat_cublas( } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); - - const half alpha = 1.0f; - const half beta = 0.0f; + const float alpha = 1.0f; + const float beta = 0.0f; CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( @@ -6397,15 +6394,10 @@ inline void ggml_cuda_op_mul_mat_cublas( row_diff, src1_ncols, ne10, &alpha, src0_ptr, CUDA_R_16F, ne00, src1_ptr, CUDA_R_16F, ne10, - &beta, dst_f16, CUDA_R_16F, ldc, - CUBLAS_COMPUTE_16F, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - - ggml_cuda_pool_free(dst_f16, dst_as); - if (src0_as != 0) { ggml_cuda_pool_free(src0_as_f16, src0_as); } From 12cc80cb8975aea3bc9f39d3c9b84f7001ab94c5 Mon Sep 17 00:00:00 2001 From: Ebey Abraham Date: Fri, 15 Dec 2023 20:56:57 +0000 Subject: [PATCH 04/13] phi2 implementation --- convert-hf-to-gguf.py | 19 ++++ gguf-py/gguf/constants.py | 13 +++ gguf-py/gguf/tensor_mapping.py | 8 ++ llama.cpp | 187 ++++++++++++++++++++++++++++++++- 4 files changed, 226 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e46a7813a..b56be8448 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -182,6 +182,8 @@ class Model: return QwenModel if model_architecture == "MixtralForCausalLM": return MixtralModel + if model_architecture == "PhiForCausalLM": + return Phi2Model return Model def _is_model_safetensors(self) -> bool: @@ -221,6 +223,8 @@ class Model: return gguf.MODEL_ARCH.QWEN if arch == "MixtralForCausalLM": return gguf.MODEL_ARCH.LLAMA + if arch == "PhiForCausalLM": + return gguf.MODEL_ARCH.PHI2 raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -980,6 +984,21 @@ class QwenModel(Model): print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") 
self.gguf_writer.add_tensor(new_name, data) +class Phi2Model(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("Phi2") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(self.hparams["n_head"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) + self.gguf_writer.add_file_type(self.ftype) + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 12133882b..390dca049 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() + PHI2 = auto() class MODEL_TENSOR(IntEnum): @@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.PHI2: "phi2", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -350,6 +352,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_ARCH.GPT2: [ # TODO ], + MODEL_ARCH.PHI2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ] # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 0115ea1c6..6fcbdbc1c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -17,6 +17,7 @@ class TensorNameMap: "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert "language_model.embedding.word_embeddings", # persimmon + "transformer.embd.wte", # phi2 ), # Token type embeddings @@ -41,6 +42,7 @@ class TensorNameMap: "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen "output", # llama-pth bloom "word_embeddings_for_head", # persimmon + "lm_head.linear", # phi2 ), # Output norm @@ -53,6 +55,7 @@ class TensorNameMap: "transformer.norm_f", # mpt "ln_f", # refact bloom qwen "language_model.encoder.final_layernorm", # persimmon + "lm_head.ln", # phi2 ), # Rope frequencies @@ -75,6 +78,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi + "transformer.h.{bid}.ln", # phi2 ), # Attention norm 2 @@ -90,6 +94,7 @@ class TensorNameMap: "transformer.h.{bid}.self_attention.query_key_value", # falcon "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon + "transformer.h.{bid}.mixer.Wqkv", # phi2 ), # Attention query @@ -128,6 +133,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "transformer.h.{bid}.mixer.out_proj", # phi2 ), # Rotary embeddings @@ -167,6 +173,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.fc_in", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen + "transformer.h.{bid}.mlp.fc1", # phi2 ), MODEL_TENSOR.FFN_UP_EXP: ( 
@@ -198,6 +205,7 @@ class TensorNameMap: "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "transformer.h.{bid}.mlp.fc2", # phi2 ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/llama.cpp b/llama.cpp index eddb70859..e229ecfe3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -195,6 +195,7 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, + LLM_ARCH_PHI2, LLM_ARCH_UNKNOWN, }; @@ -212,6 +213,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_PHI2, "phi2" }, }; enum llm_kv { @@ -550,6 +552,19 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHI2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, @@ -1420,6 +1435,7 @@ struct llama_model { struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; struct ggml_tensor * output; + struct ggml_tensor * output_b; std::vector layers; @@ -3625,7 +3641,77 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_PHI2: + { + // TODO: CPU-only for now + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + vram_weights += ggml_nbytes(model.output_b); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5417,6 +5503,101 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); + return gf; + } + struct ggml_cgraph * build_phi2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * attn_norm_output; + struct ggml_tensor * ffn_output; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + for (int il = 0; il < n_layer; ++il) { + + attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(attn_norm_output, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + 
cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + // RoPE + Qcur = ggml_rope(ctx0, Qcur, inp_pos, 32, 2, 0); + Kcur = ggml_rope(ctx0, Kcur, inp_pos, 32, 2, 0); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + // FF + { + ffn_output = llm_build_ffn(ctx0, attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(ffn_output, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ggml_add_inplace(ctx0, ffn_output, inpL)); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + cur = ggml_add(ctx0, cur, model.output_b); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; } }; @@ -5917,6 +6098,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; + case LLM_ARCH_PHI2: + { + result = llm.build_phi2(); + } break; default: GGML_ASSERT(false); } @@ -6051,7 +6236,7 @@ static int llama_decode_internal( ggml_allocr_alloc_graph(lctx.alloc, gf); struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 3]; GGML_ASSERT(strcmp(res->name, "result_output") == 0); GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); From e20765534d3b098afb804c7660c66b5e1d4719cb Mon Sep 17 00:00:00 2001 From: Ebey Abraham Date: Sat, 16 Dec 2023 00:41:06 +0000 Subject: [PATCH 05/13] fix breaking change --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index e229ecfe3..162692ce8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5591,7 +5591,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); + cb(cur, "result_norm", -1); cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); @@ -6236,7 +6236,7 @@ static int llama_decode_internal( ggml_allocr_alloc_graph(lctx.alloc, gf); struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 3]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; GGML_ASSERT(strcmp(res->name, "result_output") == 0); GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); From a2a3d2c8d7ce8df3c8e81f083d6d972c364e6b67 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 10:46:18 +0200 Subject: [PATCH 06/13] phi-2 : various fixes --- convert-hf-to-gguf.py | 4 +- ggml-cuda.cu | 11 ++++- ggml.c | 4 ++ llama.cpp | 83 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 1 + 5 files changed, 66 insertions(+), 37 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b56be8448..aed84f1b2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -249,7 +249,7 @@ class Model: toktypes.append(gguf.TokenType.USER_DEFINED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: + if 
hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special: toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.USER_DEFINED) @@ -998,7 +998,7 @@ class Phi2Model(Model): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) self.gguf_writer.add_file_type(self.ftype) - + ###### CONVERSION LOGIC ###### diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0a63c1ecf..b2e086a19 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4998,7 +4998,16 @@ static __global__ void rope_neox( const int ib = col / n_dims; const int ic = col % n_dims; - const int i = row*ncols + ib*n_dims + ic/2; + // IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2) + // I don't know what we are supposed to compute, because the row is not divisible by n_dims + // this check matches the CPU code, but it is likely wrong as well + // I can't understand the Python code, so if you know what to do here, please fix it + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 + if (ncols % n_dims != 0 && ib == ncols/n_dims) { + return; + } + + const int i = row*ncols + ib*n_dims + ic/2; const int i2 = row/p_delta_rows; float cur_rot = inv_ndims * ic - ib; diff --git a/ggml.c b/ggml.c index 1feb7ead3..772be5616 100644 --- a/ggml.c +++ b/ggml.c @@ -9168,6 +9168,8 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -9237,6 +9239,8 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps > 0.0f); + // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { diff --git a/llama.cpp b/llama.cpp index 162692ce8..b4e62341b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2998,7 +2998,7 @@ static void llm_load_tensors( (void) main_gpu; - enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; + enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU; #ifdef GGML_USE_CUBLAS @@ -3643,9 +3643,7 @@ static void llm_load_tensors( } break; case LLM_ARCH_PHI2: { - // TODO: CPU-only for now - - model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { @@ -3654,7 +3652,7 @@ static void llm_load_tensors( if (n_gpu_layers > int(n_layer)) { backend_norm = llama_backend_offload; - backend_output = llama_backend_offload_split; + backend_output = llama_backend_offload; } else { backend_norm = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU; @@ -3663,13 +3661,11 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); + model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), 
{n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.output_norm); vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { vram_weights += ggml_nbytes(model.output); vram_weights += ggml_nbytes(model.output_b); } @@ -3687,20 +3683,20 @@ static void llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -5401,15 +5397,15 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // inp_pos - contains the positions - struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); // KQ_scale - struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5528,8 +5524,12 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); - for (int il = 0; il < n_layer; ++il) { + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + for (int il = 0; il < n_layer; ++il) { attn_norm_output = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -5552,14 +5552,19 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - 
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - // RoPE - Qcur = ggml_rope(ctx0, Qcur, inp_pos, 32, 2, 0); - Kcur = ggml_rope(ctx0, Kcur, inp_pos, 32, 2, 0); - cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); cb(Kcur, "Kcur", il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); @@ -5580,8 +5585,13 @@ struct llm_build_context { cb(ffn_output, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ggml_add_inplace(ctx0, ffn_output, inpL)); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_output); + cb(cur, "l_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + inpL = cur; } cur = llm_build_norm(ctx0, inpL, hparams, @@ -5589,9 +5599,9 @@ struct llm_build_context { model.output_norm_b, LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - + cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_norm", -1); + cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); @@ -5613,7 +5623,7 @@ enum llm_offload_func_e { OFFLOAD_FUNC_FRC, // force offload OFFLOAD_FUNC_KQV, OFFLOAD_FUNC_NR, - OFFLOAD_FUNC_EMB, + OFFLOAD_FUNC_EMB, // embeddings OFFLOAD_FUNC_OUT, }; @@ -5782,6 +5792,7 @@ static const std::unordered_map k_offload_map { "l_out", OFFLOAD_FUNC }, { "result_norm", OFFLOAD_FUNC_EMB }, + { "result_output_no_bias", OFFLOAD_FUNC_EMB }, { "result_output", OFFLOAD_FUNC_OUT }, }; @@ -6235,12 +6246,16 @@ static int llama_decode_internal( ggml_allocr_alloc_graph(lctx.alloc, gf); - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + // the output is always the last tensor in the graph + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + + // the embeddings could be the second to last tensor, or the third to last tensor struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } #ifdef GGML_USE_CUBLAS for (int i = 0; i < gf->n_leafs; i++) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index df2c3fb6e..f04b9438a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2) } test_cases.emplace_back(new test_alibi()); From aa5c881adb143a378dfbe1a17dedf6337dad2ed1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 
10:54:10 +0200 Subject: [PATCH 07/13] phi-2 : use layer norm eps --- llama.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/llama.cpp b/llama.cpp index b4e62341b..1f7b02524 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2646,6 +2646,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PHI2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } From 7500fa2f073403594efb13642b5c9093987f38ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 11:01:02 +0200 Subject: [PATCH 08/13] py : whitespaces --- convert-hf-to-gguf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index aed84f1b2..8767c89a8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -984,6 +984,7 @@ class QwenModel(Model): print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") self.gguf_writer.add_tensor(new_name, data) + class Phi2Model(Model): def set_gguf_parameters(self): block_count = self.hparams["n_layer"] @@ -999,6 +1000,7 @@ class Phi2Model(Model): self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) self.gguf_writer.add_file_type(self.ftype) + ###### CONVERSION LOGIC ###### From 5469d82d5aa78c5f7e4d12d7be14c049659337db Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 11:19:56 +0200 Subject: [PATCH 09/13] llama : fix meta KV override bug --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 1f7b02524..df59c1e52 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1949,7 +1949,7 @@ namespace GGUFMeta { target = override->bool_value; return true; } - return true; + return false; } template From a878be4cb1f525ab1ab6960774b81ba157cee169 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 11:20:11 +0200 Subject: [PATCH 10/13] convert : phi don't add BOS token --- convert-hf-to-gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8767c89a8..8aaf67384 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -999,6 +999,7 @@ class Phi2Model(Model): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_add_bos_token(False) ###### CONVERSION LOGIC ###### From 0b6ffa580c5b6db743d21cab82a048dcfc496d6b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 16:05:35 +0200 Subject: [PATCH 11/13] convert : revert "added_tokens_decoder" change --- convert-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8aaf67384..e71a96c48 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -249,7 +249,7 @@ class Model: toktypes.append(gguf.TokenType.USER_DEFINED) elif reverse_vocab[i] in added_vocab: tokens.append(reverse_vocab[i]) - if hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special: + if tokenizer.added_tokens_decoder[i].special: toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.USER_DEFINED) From 0644c3be514a8dccec6ba5c77f2392a7a2c0a5ca Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Dec 2023 18:01:08 +0200 
Subject: [PATCH 12/13] phi-2 : scale Q instead of KQ for better precision --- llama.cpp | 53 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index df59c1e52..22343e882 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4088,6 +4088,7 @@ static struct ggml_tensor * llm_build_kqv( int32_t n_tokens, int32_t n_kv, float max_alibi_bias, + float scale, const llm_build_cb & cb, int il) { const int64_t n_embd = hparams.n_embd; @@ -4129,7 +4130,7 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_soft_max(ctx, kq); cb(kq, "kq_soft_max", il); } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale); cb(kq, "kq_soft_max_ext", il); } @@ -4338,7 +4339,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4521,7 +4522,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4645,7 +4646,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4745,7 +4746,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4954,7 +4955,7 @@ struct llm_build_context { // TODO: not tested, could be broken cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5045,7 +5046,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5142,7 +5143,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5236,7 +5237,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5349,7 +5350,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, 
KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5466,7 +5467,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5525,6 +5526,10 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); + // Q_scale + struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(Q_scale, "Q_scale", -1); + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); @@ -5570,6 +5575,9 @@ struct llm_build_context { ); cb(Qcur, "Qcur", il); + Qcur = ggml_scale(ctx0, Qcur, Q_scale); + cb(Qcur, "Qcur", il); + Kcur = ggml_rope_custom( ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow @@ -5580,7 +5588,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il); cb(cur, "kqv_out", il); } @@ -5717,6 +5725,7 @@ static const std::unordered_map k_offload_map { "pos_embd", OFFLOAD_FUNC_NR }, { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope) + { "Q_scale", OFFLOAD_FUNC_FRC }, { "KQ_scale", OFFLOAD_FUNC_FRC }, { "KQ_mask", OFFLOAD_FUNC_FRC }, { "K_shift", OFFLOAD_FUNC_FRC }, @@ -5819,6 +5828,7 @@ static struct ggml_cgraph * llama_build_graph( bool alloc_inp_tokens = false; bool alloc_inp_embd = false; bool alloc_inp_pos = false; + bool alloc_inp_Q_scale = false; bool alloc_inp_KQ_scale = false; bool alloc_inp_KQ_mask = false; bool alloc_inp_K_shift = false; @@ -5886,7 +5896,7 @@ static struct ggml_cgraph * llama_build_graph( alloc_inp_pos = true; } - if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { + if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) { ggml_allocr_alloc(lctx.alloc, cur); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -5894,6 +5904,23 @@ static struct ggml_cgraph * llama_build_graph( ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); } + alloc_inp_Q_scale = true; + } + + if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + const int64_t n_embd_head = model.hparams.n_embd_head(); + if (model.arch == LLM_ARCH_PHI2) { + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + ggml_set_f32(cur, 1.0f); + } else { + ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); + } + } + alloc_inp_KQ_scale = true; } From b672c169ca6c2737e45e9b851f7ee2ce48e8daa1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Dec 2023 08:39:18 +0200 Subject: [PATCH 13/13] ggml : fix NeoX rope to rotate just first n_dims --- ggml-cuda.cu | 46 ++++++++++++++++++++++------------------------ ggml-metal.metal | 13 +++++++++++-- ggml.c | 34 ++++++++++++++++++++++++++++------ 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 
b2e086a19..29df3904f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4998,31 +4998,29 @@ static __global__ void rope_neox( const int ib = col / n_dims; const int ic = col % n_dims; - // IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2) - // I don't know what we are supposed to compute, because the row is not divisible by n_dims - // this check matches the CPU code, but it is likely wrong as well - // I can't understand the Python code, so if you know what to do here, please fix it - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - if (ncols % n_dims != 0 && ib == ncols/n_dims) { - return; + if (ib == 0) { + const int i = row*ncols + ib*n_dims + ic/2; + const int i2 = row/p_delta_rows; + + float cur_rot = inv_ndims * ic - ib; + + const int p = has_pos ? pos[i2] : 0; + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + n_dims/2]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int i = row*ncols + ib*n_dims + ic; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; } - - const int i = row*ncols + ib*n_dims + ic/2; - const int i2 = row/p_delta_rows; - - float cur_rot = inv_ndims * ic - ib; - - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); - - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } static __global__ void rope_glm_f32( diff --git a/ggml-metal.metal b/ggml-metal.metal index fe0ada445..d5b54e112 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1702,8 +1702,9 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { + for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { + if (ic < n_dims) { + const int64_t ib = 0; // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; @@ -1722,6 +1723,14 @@ kernel void kernel_rope( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } diff --git a/ggml.c b/ggml.c index 772be5616..08e8a8c1d 100644 --- a/ggml.c +++ b/ggml.c @@ -11408,10 +11408,13 @@ static void ggml_compute_forward_rope_f32( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < 
ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11434,6 +11437,14 @@ static void ggml_compute_forward_rope_f32( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const int64_t i0 = ic; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } @@ -11561,10 +11572,13 @@ static void ggml_compute_forward_rope_f16( } } else { // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + // it seems we have to rope just the first n_dims elements and do nothing with the rest + // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 theta_base *= freq_scale; - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 0; ic < ne0; ic += 2) { + if (ic < n_dims) { + const int64_t ib = 0; + // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; @@ -11587,6 +11601,14 @@ static void ggml_compute_forward_rope_f16( dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } else { + const int64_t i0 = ic; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } }
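
For reference, the behaviour introduced by PATCH 12 and PATCH 13 can be summarized with two small NumPy sketches. These are illustrative only, not the ggml/CUDA code: the helper name rope_neox_partial is made up here, and freq_scale/YaRN corrections are omitted for brevity.

PATCH 13 makes the NeoX RoPE rotate only the first n_dims components of each head and pass the rest through unchanged (phi-2 has head size ne0 == 80 but rotary_dim n_dims == 32), pairing element j with element j + n_dims/2:

    import numpy as np

    def rope_neox_partial(x, pos, n_dims, freq_base=10000.0):
        # rotate the first n_dims components, copy the remaining ne0 - n_dims
        out = x.astype(np.float32).copy()
        half = n_dims // 2
        for j in range(half):                    # j indexes the rotation pair
            theta = pos * freq_base ** (-2.0 * j / n_dims)
            c, s = np.cos(theta), np.sin(theta)
            x0, x1 = out[j], out[j + half]       # NeoX pairs (j, j + n_dims/2)
            out[j]        = x0 * c - x1 * s
            out[j + half] = x0 * s + x1 * c
        return out

    x = np.arange(80, dtype=np.float32)          # ne0 == 80, as in phi-2
    y = rope_neox_partial(x, pos=3, n_dims=32)
    assert np.array_equal(y[32:], x[32:])        # tail is passed through unchanged

PATCH 12 relies on the identity softmax(K^T (s*Q)) == softmax(s * (K^T Q)) with s = 1/sqrt(n_embd_head): scaling Q before the matmul keeps the intermediate K^T*Q values in a smaller range, which is the precision concern the commit message refers to when the product is computed in half precision:

    d = 64
    rng = np.random.default_rng(0)
    Q = rng.standard_normal(d).astype(np.float32)
    K = rng.standard_normal((8, d)).astype(np.float32)

    kq_then_scale = (K @ Q) / np.sqrt(d)   # scale the K^T*Q product (previous path)
    scale_q_first = K @ (Q / np.sqrt(d))   # scale Q before the matmul (phi-2 path)
    assert np.allclose(kq_then_scale, scale_q_first)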