Merge remote-tracking branch 'origin/master' into sl/ggml-backend-int
This commit is contained in:
commit
94507911bb
11 changed files with 485 additions and 79 deletions
|
@ -182,6 +182,8 @@ class Model:
|
||||||
return QwenModel
|
return QwenModel
|
||||||
if model_architecture == "MixtralForCausalLM":
|
if model_architecture == "MixtralForCausalLM":
|
||||||
return MixtralModel
|
return MixtralModel
|
||||||
|
if model_architecture == "PhiForCausalLM":
|
||||||
|
return Phi2Model
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
|
@ -221,6 +223,8 @@ class Model:
|
||||||
return gguf.MODEL_ARCH.QWEN
|
return gguf.MODEL_ARCH.QWEN
|
||||||
if arch == "MixtralForCausalLM":
|
if arch == "MixtralForCausalLM":
|
||||||
return gguf.MODEL_ARCH.LLAMA
|
return gguf.MODEL_ARCH.LLAMA
|
||||||
|
if arch == "PhiForCausalLM":
|
||||||
|
return gguf.MODEL_ARCH.PHI2
|
||||||
|
|
||||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||||
|
|
||||||
|
@ -980,6 +984,24 @@ class QwenModel(Model):
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
|
class Phi2Model(Model):
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
block_count = self.hparams["n_layer"]
|
||||||
|
|
||||||
|
self.gguf_writer.add_name("Phi2")
|
||||||
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||||
|
self.gguf_writer.add_block_count(block_count)
|
||||||
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
|
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
|
||||||
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
|
||||||
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
self.gguf_writer.add_add_bos_token(False)
|
||||||
|
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -203,7 +203,7 @@ actor LlamaContext {
|
||||||
var pp_std: Double = 0
|
var pp_std: Double = 0
|
||||||
var tg_std: Double = 0
|
var tg_std: Double = 0
|
||||||
|
|
||||||
for r in 0..<nr {
|
for _ in 0..<nr {
|
||||||
// bench prompt processing
|
// bench prompt processing
|
||||||
|
|
||||||
llama_batch_clear(&batch)
|
llama_batch_clear(&batch)
|
||||||
|
|
|
@ -75,21 +75,56 @@ struct ContentView: View {
|
||||||
VStack {
|
VStack {
|
||||||
DownloadButton(
|
DownloadButton(
|
||||||
llamaState: llamaState,
|
llamaState: llamaState,
|
||||||
modelName: "TinyLlama-1.1B (Q4_0)",
|
modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
|
||||||
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
|
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
|
||||||
filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
|
filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
|
||||||
)
|
)
|
||||||
.font(.system(size: 12))
|
.font(.system(size: 12))
|
||||||
.padding(.top, 4)
|
.padding(.top, 4)
|
||||||
|
.frame(maxWidth: .infinity, alignment: .leading)
|
||||||
|
|
||||||
DownloadButton(
|
DownloadButton(
|
||||||
llamaState: llamaState,
|
llamaState: llamaState,
|
||||||
modelName: "TinyLlama-1.1B (Q8_0)",
|
modelName: "TinyLlama-1.1B (Q8_0, 1.1 GiB)",
|
||||||
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
|
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
|
||||||
filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
|
filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
|
||||||
)
|
)
|
||||||
.font(.system(size: 12))
|
.font(.system(size: 12))
|
||||||
|
|
||||||
|
DownloadButton(
|
||||||
|
llamaState: llamaState,
|
||||||
|
modelName: "TinyLlama-1.1B (F16, 2.2 GiB)",
|
||||||
|
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
|
||||||
|
filename: "tinyllama-1.1b-f16.gguf"
|
||||||
|
)
|
||||||
|
.font(.system(size: 12))
|
||||||
|
.frame(maxWidth: .infinity, alignment: .leading)
|
||||||
|
|
||||||
|
DownloadButton(
|
||||||
|
llamaState: llamaState,
|
||||||
|
modelName: "Phi-2.7B (Q4_0, 1.6 GiB)",
|
||||||
|
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
|
||||||
|
filename: "phi-2-q4_0.gguf"
|
||||||
|
)
|
||||||
|
.font(.system(size: 12))
|
||||||
|
|
||||||
|
DownloadButton(
|
||||||
|
llamaState: llamaState,
|
||||||
|
modelName: "Phi-2.7B (Q8_0, 2.8 GiB)",
|
||||||
|
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
|
||||||
|
filename: "phi-2-q8_0.gguf"
|
||||||
|
)
|
||||||
|
.font(.system(size: 12))
|
||||||
|
.frame(maxWidth: .infinity, alignment: .leading)
|
||||||
|
|
||||||
|
DownloadButton(
|
||||||
|
llamaState: llamaState,
|
||||||
|
modelName: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
|
||||||
|
modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
|
||||||
|
filename: "mistral-7b-v0.1.Q4_0.gguf"
|
||||||
|
)
|
||||||
|
.font(.system(size: 12))
|
||||||
|
|
||||||
Button("Clear downloaded models") {
|
Button("Clear downloaded models") {
|
||||||
ContentView.cleanupModelCaches()
|
ContentView.cleanupModelCaches()
|
||||||
llamaState.cacheCleared = true
|
llamaState.cacheCleared = true
|
||||||
|
|
105
ggml-cuda.cu
105
ggml-cuda.cu
|
@ -31,6 +31,7 @@
|
||||||
#define CUDA_R_16F HIPBLAS_R_16F
|
#define CUDA_R_16F HIPBLAS_R_16F
|
||||||
#define CUDA_R_32F HIPBLAS_R_32F
|
#define CUDA_R_32F HIPBLAS_R_32F
|
||||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||||
|
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||||
#define cublasCreate hipblasCreate
|
#define cublasCreate hipblasCreate
|
||||||
#define cublasGemmEx hipblasGemmEx
|
#define cublasGemmEx hipblasGemmEx
|
||||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||||
|
@ -40,6 +41,7 @@
|
||||||
#define cublasSetStream hipblasSetStream
|
#define cublasSetStream hipblasSetStream
|
||||||
#define cublasSgemm hipblasSgemm
|
#define cublasSgemm hipblasSgemm
|
||||||
#define cublasStatus_t hipblasStatus_t
|
#define cublasStatus_t hipblasStatus_t
|
||||||
|
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
||||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
|
@ -4998,6 +5000,15 @@ static __global__ void rope_neox(
|
||||||
const int ib = col / n_dims;
|
const int ib = col / n_dims;
|
||||||
const int ic = col % n_dims;
|
const int ic = col % n_dims;
|
||||||
|
|
||||||
|
if (ib > 0) {
|
||||||
|
const int i = row*ncols + ib*n_dims + ic;
|
||||||
|
|
||||||
|
dst[i + 0] = x[i + 0];
|
||||||
|
dst[i + 1] = x[i + 1];
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const int i = row*ncols + ib*n_dims + ic/2;
|
const int i = row*ncols + ib*n_dims + ic/2;
|
||||||
const int i2 = row/p_delta_rows;
|
const int i2 = row/p_delta_rows;
|
||||||
|
|
||||||
|
@ -7378,7 +7389,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
|
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||||
half * src0_as_f16 = nullptr;
|
half * src0_as_f16 = nullptr;
|
||||||
size_t src0_as = 0;
|
size_t src0_as = 0;
|
||||||
|
@ -8302,27 +8313,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void k_compute_batched_ptrs(
|
static __global__ void k_compute_batched_ptrs(
|
||||||
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
|
const half * src0_as_f16, const half * src1_as_f16, char * dst,
|
||||||
const void ** ptrs_src, void ** ptrs_dst,
|
const void ** ptrs_src, void ** ptrs_dst,
|
||||||
int ne12, int ne13,
|
int64_t ne12, int64_t ne13,
|
||||||
int ne23,
|
int64_t ne23,
|
||||||
int nb02, int nb03,
|
size_t nb02, size_t nb03,
|
||||||
int nb12, int nb13,
|
size_t nb12, size_t nb13,
|
||||||
int nb2, int nb3,
|
size_t nbd2, size_t nbd3,
|
||||||
int r2, int r3) {
|
int64_t r2, int64_t r3) {
|
||||||
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
|
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
|
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
if (i13 >= ne13 || i12 >= ne12) {
|
if (i13 >= ne13 || i12 >= ne12) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int i03 = i13 / r3;
|
int64_t i03 = i13 / r3;
|
||||||
int i02 = i12 / r2;
|
int64_t i02 = i12 / r2;
|
||||||
|
|
||||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
||||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
|
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
@ -8378,7 +8389,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||||
|
|
||||||
size_t dst_as = 0;
|
size_t dst_as = 0;
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
|
||||||
|
half * dst_f16 = nullptr;
|
||||||
|
char * dst_t = nullptr;
|
||||||
|
|
||||||
|
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||||
|
cudaDataType_t cu_data_type = CUDA_R_16F;
|
||||||
|
|
||||||
|
// dst strides
|
||||||
|
size_t nbd2 = dst->nb[2];
|
||||||
|
size_t nbd3 = dst->nb[3];
|
||||||
|
|
||||||
|
const half alpha_f16 = 1.0f;
|
||||||
|
const half beta_f16 = 0.0f;
|
||||||
|
|
||||||
|
const float alpha_f32 = 1.0f;
|
||||||
|
const float beta_f32 = 0.0f;
|
||||||
|
|
||||||
|
const void * alpha = &alpha_f16;
|
||||||
|
const void * beta = &beta_f16;
|
||||||
|
|
||||||
|
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
|
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||||
|
dst_t = (char *) dst_f16;
|
||||||
|
|
||||||
|
nbd2 /= sizeof(float) / sizeof(half);
|
||||||
|
nbd3 /= sizeof(float) / sizeof(half);
|
||||||
|
} else {
|
||||||
|
dst_t = (char *) dst_ddf;
|
||||||
|
|
||||||
|
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||||
|
cu_data_type = CUDA_R_32F;
|
||||||
|
|
||||||
|
alpha = &alpha_f32;
|
||||||
|
beta = &beta_f32;
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
GGML_ASSERT(ne13 % ne03 == 0);
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
@ -8387,9 +8432,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
const int64_t r2 = ne12/ne02;
|
const int64_t r2 = ne12/ne02;
|
||||||
const int64_t r3 = ne13/ne03;
|
const int64_t r3 = ne13/ne03;
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
|
||||||
const half beta_f16 = 0.0f;
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// use cublasGemmEx
|
// use cublasGemmEx
|
||||||
{
|
{
|
||||||
|
@ -8399,12 +8441,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
int i02 = i12 / r2;
|
int i02 = i12 / r2;
|
||||||
|
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
||||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
||||||
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
|
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
|
||||||
CUBLAS_COMPUTE_16F,
|
cu_compute_type,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8416,11 +8458,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||||
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
|
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||||
ne12*ne13,
|
ne12*ne13,
|
||||||
CUBLAS_COMPUTE_16F,
|
cu_compute_type,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
} else {
|
} else {
|
||||||
// use cublasGemmBatchedEx
|
// use cublasGemmBatchedEx
|
||||||
|
@ -8437,24 +8479,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
|
|
||||||
dim3 block_dims(ne13, ne12);
|
dim3 block_dims(ne13, ne12);
|
||||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||||
src0_as_f16, src1_as_f16, dst_f16,
|
src0_as_f16, src1_as_f16, dst_t,
|
||||||
ptrs_src, ptrs_dst,
|
ptrs_src, ptrs_dst,
|
||||||
ne12, ne13,
|
ne12, ne13,
|
||||||
ne23,
|
ne23,
|
||||||
nb02, nb03,
|
nb02, nb03,
|
||||||
nb12, nb13,
|
nb12, nb13,
|
||||||
dst->nb[2], dst->nb[3],
|
nbd2, nbd3,
|
||||||
r2, r3);
|
r2, r3);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||||
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
|
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
|
||||||
ne23,
|
ne23,
|
||||||
CUBLAS_COMPUTE_16F,
|
cu_compute_type,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
if (ptrs_src_s != 0) {
|
if (ptrs_src_s != 0) {
|
||||||
|
@ -8466,11 +8508,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||||
|
|
||||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
|
||||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
|
|
@ -1702,8 +1702,9 @@ kernel void kernel_rope(
|
||||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
||||||
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
if (ic < n_dims) {
|
||||||
|
const int64_t ib = 0;
|
||||||
|
|
||||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||||
const float cur_rot = inv_ndims*ic - ib;
|
const float cur_rot = inv_ndims*ic - ib;
|
||||||
|
@ -1722,6 +1723,14 @@ kernel void kernel_rope(
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
|
} else {
|
||||||
|
const int64_t i0 = ic;
|
||||||
|
|
||||||
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
dst_data[0] = src[0];
|
||||||
|
dst_data[1] = src[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
46
ggml.c
46
ggml.c
|
@ -4086,6 +4086,14 @@ struct ggml_tensor * ggml_mul_mat(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_mul_mat_set_prec(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_prec prec) {
|
||||||
|
const int32_t prec_i32 = (int32_t) prec;
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(a, 0, prec_i32);
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_mul_mat_id
|
// ggml_mul_mat_id
|
||||||
|
|
||||||
struct ggml_tensor * ggml_mul_mat_id(
|
struct ggml_tensor * ggml_mul_mat_id(
|
||||||
|
@ -9156,6 +9164,8 @@ static void ggml_compute_forward_norm_f32(
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
GGML_ASSERT(eps > 0.0f);
|
||||||
|
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
@ -9225,6 +9235,8 @@ static void ggml_compute_forward_rms_norm_f32(
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
GGML_ASSERT(eps > 0.0f);
|
||||||
|
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
@ -11550,10 +11562,13 @@ static void ggml_compute_forward_rope_f32(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// TODO: this might be wrong for ne0 != n_dims - need double check
|
// TODO: this might be wrong for ne0 != n_dims - need double check
|
||||||
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
|
// it seems we have to rope just the first n_dims elements and do nothing with the rest
|
||||||
|
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
|
||||||
theta_base *= freq_scale;
|
theta_base *= freq_scale;
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
if (ic < n_dims) {
|
||||||
|
const int64_t ib = 0;
|
||||||
|
|
||||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||||
float cur_rot = inv_ndims * ic - ib;
|
float cur_rot = inv_ndims * ic - ib;
|
||||||
|
|
||||||
|
@ -11576,6 +11591,14 @@ static void ggml_compute_forward_rope_f32(
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
|
} else {
|
||||||
|
const int64_t i0 = ic;
|
||||||
|
|
||||||
|
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
dst_data[0] = src[0];
|
||||||
|
dst_data[1] = src[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11703,10 +11726,13 @@ static void ggml_compute_forward_rope_f16(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// TODO: this might be wrong for ne0 != n_dims - need double check
|
// TODO: this might be wrong for ne0 != n_dims - need double check
|
||||||
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
|
// it seems we have to rope just the first n_dims elements and do nothing with the rest
|
||||||
|
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
|
||||||
theta_base *= freq_scale;
|
theta_base *= freq_scale;
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
if (ic < n_dims) {
|
||||||
|
const int64_t ib = 0;
|
||||||
|
|
||||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||||
float cur_rot = inv_ndims * ic - ib;
|
float cur_rot = inv_ndims * ic - ib;
|
||||||
|
|
||||||
|
@ -11729,6 +11755,14 @@ static void ggml_compute_forward_rope_f16(
|
||||||
|
|
||||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||||
|
} else {
|
||||||
|
const int64_t i0 = ic;
|
||||||
|
|
||||||
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
dst_data[0] = src[0];
|
||||||
|
dst_data[1] = src[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
12
ggml.h
12
ggml.h
|
@ -343,6 +343,12 @@ extern "C" {
|
||||||
GGML_TYPE_COUNT,
|
GGML_TYPE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// precision
|
||||||
|
enum ggml_prec {
|
||||||
|
GGML_PREC_DEFAULT,
|
||||||
|
GGML_PREC_F32,
|
||||||
|
};
|
||||||
|
|
||||||
enum ggml_backend_type {
|
enum ggml_backend_type {
|
||||||
GGML_BACKEND_CPU = 0,
|
GGML_BACKEND_CPU = 0,
|
||||||
GGML_BACKEND_GPU = 10,
|
GGML_BACKEND_GPU = 10,
|
||||||
|
@ -1057,6 +1063,12 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
// change the precision of a matrix multiplication
|
||||||
|
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
|
||||||
|
GGML_API void ggml_mul_mat_set_prec(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_prec prec);
|
||||||
|
|
||||||
// indirect matrix multiplication
|
// indirect matrix multiplication
|
||||||
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
|
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
|
||||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||||
|
|
|
@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
|
||||||
BLOOM = auto()
|
BLOOM = auto()
|
||||||
STABLELM = auto()
|
STABLELM = auto()
|
||||||
QWEN = auto()
|
QWEN = auto()
|
||||||
|
PHI2 = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
|
@ -140,6 +141,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.BLOOM: "bloom",
|
MODEL_ARCH.BLOOM: "bloom",
|
||||||
MODEL_ARCH.STABLELM: "stablelm",
|
MODEL_ARCH.STABLELM: "stablelm",
|
||||||
MODEL_ARCH.QWEN: "qwen",
|
MODEL_ARCH.QWEN: "qwen",
|
||||||
|
MODEL_ARCH.PHI2: "phi2",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
|
@ -350,6 +352,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_ARCH.GPT2: [
|
MODEL_ARCH.GPT2: [
|
||||||
# TODO
|
# TODO
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.PHI2: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
]
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ class TensorNameMap:
|
||||||
"tok_embeddings", # llama-pth
|
"tok_embeddings", # llama-pth
|
||||||
"embeddings.word_embeddings", # bert
|
"embeddings.word_embeddings", # bert
|
||||||
"language_model.embedding.word_embeddings", # persimmon
|
"language_model.embedding.word_embeddings", # persimmon
|
||||||
|
"transformer.embd.wte", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Token type embeddings
|
# Token type embeddings
|
||||||
|
@ -41,6 +42,7 @@ class TensorNameMap:
|
||||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
|
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
|
||||||
"output", # llama-pth bloom
|
"output", # llama-pth bloom
|
||||||
"word_embeddings_for_head", # persimmon
|
"word_embeddings_for_head", # persimmon
|
||||||
|
"lm_head.linear", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Output norm
|
# Output norm
|
||||||
|
@ -53,6 +55,7 @@ class TensorNameMap:
|
||||||
"transformer.norm_f", # mpt
|
"transformer.norm_f", # mpt
|
||||||
"ln_f", # refact bloom qwen
|
"ln_f", # refact bloom qwen
|
||||||
"language_model.encoder.final_layernorm", # persimmon
|
"language_model.encoder.final_layernorm", # persimmon
|
||||||
|
"lm_head.ln", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rope frequencies
|
# Rope frequencies
|
||||||
|
@ -75,6 +78,7 @@ class TensorNameMap:
|
||||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||||
"model.layers.{bid}.ln1", # yi
|
"model.layers.{bid}.ln1", # yi
|
||||||
|
"transformer.h.{bid}.ln", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
|
@ -90,6 +94,7 @@ class TensorNameMap:
|
||||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||||
"h.{bid}.self_attention.query_key_value", # bloom
|
"h.{bid}.self_attention.query_key_value", # bloom
|
||||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||||
|
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention query
|
# Attention query
|
||||||
|
@ -128,6 +133,7 @@ class TensorNameMap:
|
||||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||||
|
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rotary embeddings
|
# Rotary embeddings
|
||||||
|
@ -167,6 +173,7 @@ class TensorNameMap:
|
||||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||||
"transformer.h.{bid}.mlp.w1", # qwen
|
"transformer.h.{bid}.mlp.w1", # qwen
|
||||||
|
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_EXP: (
|
MODEL_TENSOR.FFN_UP_EXP: (
|
||||||
|
@ -198,6 +205,7 @@ class TensorNameMap:
|
||||||
"encoder.layer.{bid}.output.dense", # bert
|
"encoder.layer.{bid}.output.dense", # bert
|
||||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||||
|
"transformer.h.{bid}.mlp.fc2", # phi2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||||
|
|
287
llama.cpp
287
llama.cpp
|
@ -195,6 +195,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_BLOOM,
|
LLM_ARCH_BLOOM,
|
||||||
LLM_ARCH_STABLELM,
|
LLM_ARCH_STABLELM,
|
||||||
LLM_ARCH_QWEN,
|
LLM_ARCH_QWEN,
|
||||||
|
LLM_ARCH_PHI2,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -212,6 +213,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_BLOOM, "bloom" },
|
{ LLM_ARCH_BLOOM, "bloom" },
|
||||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||||
{ LLM_ARCH_QWEN, "qwen" },
|
{ LLM_ARCH_QWEN, "qwen" },
|
||||||
|
{ LLM_ARCH_PHI2, "phi2" },
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llm_kv {
|
enum llm_kv {
|
||||||
|
@ -550,6 +552,19 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_PHI2,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
|
@ -1359,6 +1374,7 @@ struct llama_model {
|
||||||
struct ggml_tensor * output_norm;
|
struct ggml_tensor * output_norm;
|
||||||
struct ggml_tensor * output_norm_b;
|
struct ggml_tensor * output_norm_b;
|
||||||
struct ggml_tensor * output;
|
struct ggml_tensor * output;
|
||||||
|
struct ggml_tensor * output_b;
|
||||||
|
|
||||||
std::vector<llama_layer> layers;
|
std::vector<llama_layer> layers;
|
||||||
|
|
||||||
|
@ -1873,7 +1889,7 @@ namespace GGUFMeta {
|
||||||
target = override->bool_value;
|
target = override->bool_value;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename OT>
|
template<typename OT>
|
||||||
|
@ -2573,6 +2589,15 @@ static void llm_load_hparams(
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 32: model.type = e_model::MODEL_3B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
@ -3419,7 +3444,57 @@ static void llm_load_tensors(
|
||||||
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
ggml_backend_type backend_norm;
|
||||||
|
ggml_backend_type backend_output;
|
||||||
|
|
||||||
|
if (n_gpu_layers > int(n_layer)) {
|
||||||
|
backend_norm = llama_backend_offload;
|
||||||
|
backend_output = llama_backend_offload;
|
||||||
|
} else {
|
||||||
|
backend_norm = GGML_BACKEND_CPU;
|
||||||
|
backend_output = GGML_BACKEND_CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||||
|
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
|
||||||
|
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||||
|
model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_ff = hparams.n_ff;
|
||||||
|
|
||||||
|
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||||
|
|
||||||
|
model.layers.resize(n_layer);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
|
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
|
||||||
|
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||||
|
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
||||||
|
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
|
||||||
|
|
||||||
|
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||||
|
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
|
||||||
|
layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
|
||||||
|
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
|
@ -3824,6 +3899,7 @@ static struct ggml_tensor * llm_build_ffn(
|
||||||
// if max_alibi_bias > 0 then apply ALiBi
|
// if max_alibi_bias > 0 then apply ALiBi
|
||||||
static struct ggml_tensor * llm_build_kqv(
|
static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
const llama_model & model,
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_kv_cache & kv,
|
const llama_kv_cache & kv,
|
||||||
struct ggml_tensor * wo,
|
struct ggml_tensor * wo,
|
||||||
|
@ -3835,6 +3911,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
float max_alibi_bias,
|
float max_alibi_bias,
|
||||||
|
float scale,
|
||||||
const llm_build_cb & cb,
|
const llm_build_cb & cb,
|
||||||
int il) {
|
int il) {
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
@ -3857,6 +3934,12 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
|
if (model.arch == LLM_ARCH_PHI2) {
|
||||||
|
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||||
|
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
|
}
|
||||||
|
|
||||||
if (max_alibi_bias > 0.0f) {
|
if (max_alibi_bias > 0.0f) {
|
||||||
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
|
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
|
||||||
kq = ggml_scale(ctx, kq, kq_scale);
|
kq = ggml_scale(ctx, kq, kq_scale);
|
||||||
|
@ -3876,7 +3959,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
kq = ggml_soft_max(ctx, kq);
|
kq = ggml_soft_max(ctx, kq);
|
||||||
cb(kq, "kq_soft_max", il);
|
cb(kq, "kq_soft_max", il);
|
||||||
} else {
|
} else {
|
||||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
|
kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
|
||||||
cb(kq, "kq_soft_max_ext", il);
|
cb(kq, "kq_soft_max_ext", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4083,9 +4166,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4266,9 +4349,9 @@ struct llm_build_context {
|
||||||
// apply ALiBi for 13B model
|
// apply ALiBi for 13B model
|
||||||
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
|
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4390,9 +4473,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4490,9 +4573,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4699,9 +4782,9 @@ struct llm_build_context {
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
// TODO: not tested, could be broken
|
// TODO: not tested, could be broken
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4790,9 +4873,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4887,9 +4970,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4981,9 +5064,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5094,9 +5177,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5153,15 +5236,15 @@ struct llm_build_context {
|
||||||
cb(inpL, "inp_embd", -1);
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
// inp_pos - contains the positions
|
// inp_pos - contains the positions
|
||||||
struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_scale
|
// KQ_scale
|
||||||
struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
cb(KQ_scale, "KQ_scale", -1);
|
cb(KQ_scale, "KQ_scale", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
|
@ -5211,9 +5294,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5255,6 +5338,122 @@ struct llm_build_context {
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
struct ggml_cgraph * build_phi2() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * attn_norm_output;
|
||||||
|
struct ggml_tensor * ffn_output;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
|
||||||
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
|
// Q_scale
|
||||||
|
struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
cb(Q_scale, "Q_scale", -1);
|
||||||
|
|
||||||
|
// KQ_scale
|
||||||
|
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
cb(KQ_scale, "KQ_scale", -1);
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
|
// shift the entire K-cache if needed
|
||||||
|
if (do_rope_shift) {
|
||||||
|
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm,
|
||||||
|
model.layers[il].attn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(attn_norm_output, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
|
||||||
|
cb(cur, "wqkv", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||||
|
cb(cur, "bqkv", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||||
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||||
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||||
|
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
|
Qcur = ggml_rope_custom(
|
||||||
|
ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
|
||||||
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Qcur = ggml_scale(ctx0, Qcur, Q_scale);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_custom(
|
||||||
|
ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
|
||||||
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
|
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// FF
|
||||||
|
{
|
||||||
|
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
|
||||||
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
|
NULL, NULL,
|
||||||
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
|
cb(ffn_output, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_output);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, inpL);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.output_norm,
|
||||||
|
model.output_norm_b,
|
||||||
|
LLM_NORM, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
cb(cur, "result_output_no_bias", -1);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, model.output_b);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -5270,7 +5469,7 @@ enum llm_offload_func_e {
|
||||||
OFFLOAD_FUNC_FRC, // force offload
|
OFFLOAD_FUNC_FRC, // force offload
|
||||||
OFFLOAD_FUNC_KQV,
|
OFFLOAD_FUNC_KQV,
|
||||||
OFFLOAD_FUNC_NR,
|
OFFLOAD_FUNC_NR,
|
||||||
OFFLOAD_FUNC_EMB,
|
OFFLOAD_FUNC_EMB, // embeddings
|
||||||
OFFLOAD_FUNC_OUT,
|
OFFLOAD_FUNC_OUT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -5355,6 +5554,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
|
||||||
{ "pos_embd", OFFLOAD_FUNC_NR },
|
{ "pos_embd", OFFLOAD_FUNC_NR },
|
||||||
|
|
||||||
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
|
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
|
||||||
|
{ "Q_scale", OFFLOAD_FUNC_FRC },
|
||||||
{ "KQ_scale", OFFLOAD_FUNC_FRC },
|
{ "KQ_scale", OFFLOAD_FUNC_FRC },
|
||||||
{ "KQ_mask", OFFLOAD_FUNC_FRC },
|
{ "KQ_mask", OFFLOAD_FUNC_FRC },
|
||||||
{ "K_shift", OFFLOAD_FUNC_FRC },
|
{ "K_shift", OFFLOAD_FUNC_FRC },
|
||||||
|
@ -5439,6 +5639,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
|
||||||
{ "l_out", OFFLOAD_FUNC },
|
{ "l_out", OFFLOAD_FUNC },
|
||||||
|
|
||||||
{ "result_norm", OFFLOAD_FUNC_EMB },
|
{ "result_norm", OFFLOAD_FUNC_EMB },
|
||||||
|
{ "result_output_no_bias", OFFLOAD_FUNC_EMB },
|
||||||
{ "result_output", OFFLOAD_FUNC_OUT },
|
{ "result_output", OFFLOAD_FUNC_OUT },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -5456,6 +5657,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
bool alloc_inp_tokens = false;
|
bool alloc_inp_tokens = false;
|
||||||
bool alloc_inp_embd = false;
|
bool alloc_inp_embd = false;
|
||||||
bool alloc_inp_pos = false;
|
bool alloc_inp_pos = false;
|
||||||
|
bool alloc_inp_Q_scale = false;
|
||||||
bool alloc_inp_KQ_scale = false;
|
bool alloc_inp_KQ_scale = false;
|
||||||
bool alloc_inp_KQ_mask = false;
|
bool alloc_inp_KQ_mask = false;
|
||||||
bool alloc_inp_K_shift = false;
|
bool alloc_inp_K_shift = false;
|
||||||
|
@ -5523,7 +5725,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
alloc_inp_pos = true;
|
alloc_inp_pos = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
|
if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
|
||||||
ggml_allocr_alloc(lctx.alloc, cur);
|
ggml_allocr_alloc(lctx.alloc, cur);
|
||||||
|
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
|
@ -5531,6 +5733,23 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
|
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
alloc_inp_Q_scale = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
|
||||||
|
ggml_allocr_alloc(lctx.alloc, cur);
|
||||||
|
|
||||||
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
|
const int64_t n_embd_head = model.hparams.n_embd_head();
|
||||||
|
if (model.arch == LLM_ARCH_PHI2) {
|
||||||
|
// with phi2, we scale the Q to avoid precision issues
|
||||||
|
// ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
|
||||||
|
ggml_set_f32(cur, 1.0f);
|
||||||
|
} else {
|
||||||
|
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
alloc_inp_KQ_scale = true;
|
alloc_inp_KQ_scale = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5755,6 +5974,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
{
|
{
|
||||||
result = llm.build_qwen();
|
result = llm.build_qwen();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
{
|
||||||
|
result = llm.build_phi2();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
@ -5888,12 +6111,16 @@ static int llama_decode_internal(
|
||||||
|
|
||||||
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
||||||
|
|
||||||
|
// the output is always the last tensor in the graph
|
||||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
|
||||||
|
|
||||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
GGML_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
|
||||||
|
|
||||||
|
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||||
|
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||||
|
if (strcmp(embeddings->name, "result_norm") != 0) {
|
||||||
|
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||||
|
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
char * buf_alloc_base = (char *)ggml_backend_buffer_get_base(lctx.buf_alloc);
|
char * buf_alloc_base = (char *)ggml_backend_buffer_get_base(lctx.buf_alloc);
|
||||||
|
|
|
@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
|
||||||
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_alibi());
|
test_cases.emplace_back(new test_alibi());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue