diff --git a/README-sycl.md b/README-sycl.md index f6dbfd878..2aa465070 100644 --- a/README-sycl.md +++ b/README-sycl.md @@ -229,12 +229,11 @@ source /opt/intel/oneapi/setvars.sh # Build LLAMA with MKL BLAS acceleration for intel GPU mkdir -p build && cd build -# Option 1: Use FP16 for better performance in long-prompt inference -cmake --build .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL_F16=ON -# Or without "--build", run "make" next +# Option 1: Use FP16 for better performance in long-prompt inference +#cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL_F16=ON # Option 2: Use FP32 by default -cmake --build .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx #build all binary cmake --build . --config Release -j -v @@ -252,10 +251,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR mkdir -p build && cd build # Option 1: Use FP16 for better performance in long-prompt inference -cmake --build .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL_F16=ON +cmake .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL_F16=ON # Option 2: Use FP32 by default -cmake --build .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx #build all binary cmake --build . --config Release -j -v diff --git a/README.md b/README.md index e3eae60bb..086618315 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ Typically finetunes of the base models below are supported as well. - [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion) - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B) +- [x] [OLMo](https://allenai.org/olmo) (instructions for supporting more models: [HOWTO-add-model.md](./docs/HOWTO-add-model.md)) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c14186abb..358dba8ed 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2636,6 +2636,66 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("OlmoForCausalLM") +@Model.register("OLMoForCausalLM") +class OlmoModel(Model): + model_arch = gguf.MODEL_ARCH.OLMO + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_layer_norm_eps(1e-5) + if "clip_qkv" in self.hparams is not None: + self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"]) + + # Same as super class, but permuting q_proj, k_proj + # Copied from: LlamaModel + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + for name, data_torch in self.get_tensors(): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.numpy() + + if name.endswith("q_proj.weight"): + data = permute(data, n_head, n_head) + if name.endswith("k_proj.weight"): + data = permute(data, n_head, n_kv_head) + + data = data.squeeze() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # 1d tensors need to be converted to float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 73609d3e6..98c0e93e4 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -44,7 +44,7 @@ private: std::mutex m_mutex; int m_last_call = 0; std::vector m_src1_data; - std::vector m_ids; // the expert ids from ggml_mul_mat_id + std::vector m_ids; // the expert ids from ggml_mul_mat_id // void save_imatrix(const char * file_name) const; void keep_imatrix(int ncall) const; @@ -81,6 +81,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (ask) { if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications if (t->op != GGML_OP_MUL_MAT) return false; + // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false; return true; @@ -101,14 +102,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // this has been adapted to the new format of storing merged experts in a single 3d tensor // ref: https://github.com/ggerganov/llama.cpp/pull/6387 if (t->op == GGML_OP_MUL_MAT_ID) { - const int idx = ((int32_t *) t->op_params)[0]; + // ids -> [n_experts_used, n_tokens] + // src1 -> [cols, n_expert_used, n_tokens] const ggml_tensor * ids = t->src[2]; const int n_as = src0->ne[2]; + const int n_ids = ids->ne[0]; // the top-k selected expert ids are stored in the ids tensor // for simplicity, always copy ids to host, because it is small - GGML_ASSERT(ids->ne[1] == src1->ne[1]); - m_ids.resize(ggml_nbytes(ids)/sizeof(int)); + // take into account that ids is not contiguous! + + GGML_ASSERT(ids->ne[1] == src1->ne[2]); + + m_ids.resize(ggml_nbytes(ids)); ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids)); auto & e = m_stats[wname]; @@ -118,26 +124,35 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // using the following line, we can correct for that if needed by replacing the line above with: //if (idx == t->src[0]->ne[0] - 1) ++e.ncall; + if (e.values.empty()) { + e.values.resize(src1->ne[0]*n_as, 0); + } + else if (e.values.size() != (size_t)src1->ne[0]*n_as) { + fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); + exit(1); //GGML_ASSERT(false); + } + if (m_params.verbosity > 1) { + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); + } // loop over all possible experts, regardless if they are used or not in the batch for (int ex = 0; ex < n_as; ++ex) { size_t e_start = ex*src1->ne[0]; - if (e.values.empty()) { - e.values.resize(src1->ne[0]*n_as, 0); - } - else if (e.values.size() != (size_t)src1->ne[0]*n_as) { - fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); - exit(1); //GGML_ASSERT(false); - } - if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - } - for (int row = 0; row < (int)src1->ne[1]; ++row) { - const int excur = m_ids[row*n_as + idx]; - GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check - if (excur != ex) continue; - const float * x = data + row * src1->ne[0]; - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[e_start + j] += x[j]*x[j]; + + for (int idx = 0; idx < n_ids; ++idx) { + for (int row = 0; row < (int)src1->ne[2]; ++row) { + const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]); + + GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check + + if (excur != ex) continue; + + const int64_t i11 = idx % src1->ne[1]; + const int64_t i12 = row; + const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]); + + for (int j = 0; j < (int)src1->ne[0]; ++j) { + e.values[e_start + j] += x[j]*x[j]; + } } } if (e.ncall > m_last_call) { diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 7d06e401b..587418cc7 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -73,6 +73,7 @@ struct my_llama_model { static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; static const char * LLM_KV_TRAINING_TYPE = "training.type"; +static const char * LLM_KV_GENERAL_NAME = "general.name"; static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture"; static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type"; @@ -529,6 +530,7 @@ static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_contex static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) { const char * arch = "llama"; + enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32; std::vector keybuf; @@ -540,6 +542,7 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo // set arch gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch); + gguf_set_val_str(fctx, LLM_KV_GENERAL_NAME, arch); gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype); // set hparams diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2cf6c8d98..c30554f0c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1233,7 +1233,7 @@ static void ggml_cuda_op_mul_mat_cublas( if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 - ggml_cuda_pool_alloc src0_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1243,7 +1243,7 @@ static void ggml_cuda_op_mul_mat_cublas( } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get(); - ggml_cuda_pool_alloc src1_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src1_as_f16(ctx.pool(id)); if (src1->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1252,7 +1252,7 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(), row_diff*src1_ncols); + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -1962,20 +1962,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, + int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, + const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, + size_t nb11, size_t nb12) { + int32_t iid1 = blockIdx.x; + int32_t id = blockIdx.y; + + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + __shared__ int src1_row; + if (threadIdx.x == 0) { + src1_row = atomicAdd(cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + __syncthreads(); + + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); + + for (int i = threadIdx.x; i < ne10; i += blockDim.x) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne0, + size_t nb1, size_t nb2) { + int32_t i = blockIdx.x; + + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; + + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); + + for (int j = threadIdx.x; j < ne0; j += blockDim.x) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); - const size_t nb11 = src1->nb[1]; - const size_t nb1 = dst->nb[1]; - - const int32_t id = ((int32_t *) dst->op_params)[0]; - const int32_t n_as = src0->ne[2]; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -1984,7 +2037,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + ggml_tensor dst_row = *dst; char * src0_original = (char *) src0->data; char * src1_original = (char *) src1->data; @@ -1992,19 +2045,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src0_row.ne[2] = 1; src0_row.ne[3] = 1; - src0_row.nb[3] = src0->nb[2]; + src0_row.nb[3] = nb02; - if (src1->ne[1] == 1) { - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; - GGML_ASSERT(row_id >= 0 && row_id < n_as); + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; - src0_row.data = src0_original + row_id*src0->nb[2]; - src1_row.data = src1_original + i01*src1->nb[1]; - dst_row.data = dst_original + i01*dst->nb[1]; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; + + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } } } else { ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); @@ -2013,54 +2086,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int32_t row_id = 0; row_id < n_as; ++row_id) { + for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11, - nb11, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; } if (num_src1_rows == 0) { continue; } - src0_row.data = src0_original + row_id*src0->nb[2]; + ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ne11, ne10, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } + + src0_row.data = src0_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); src1_row.ne[1] = num_src1_rows; - dst_row.ne[1] = num_src1_rows; - src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); - num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i != row_id) { - continue; - } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1, - nb1, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; + { + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_contiguous.get(), + dev_row_mapping.get(), + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); } } } @@ -2493,7 +2581,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); GGML_UNUSED(backend); } diff --git a/ggml-cuda/binbcast.cu b/ggml-cuda/binbcast.cu index 959eaed95..19b08b74f 100644 --- a/ggml-cuda/binbcast.cu +++ b/ggml-cuda/binbcast.cu @@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, /*int s10,*/ int s11, int s12, int s13) { const int i0s = blockDim.x*blockIdx.x + threadIdx.x; const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); @@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst const int i12 = i2 % ne12; const int i13 = i3 % ne13; - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; const src0_t * src0_row = src0 + i_src0; const src1_t * src1_row = src1 + i_src1; @@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, /*int s10,*/ int s11, int s12, int s13) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s const int i12 = i2 % ne12; const int i13 = i3 % ne13; - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; const src0_t * src0_row = src0 + i_src0; const src1_t * src1_row = src1 + i_src1; @@ -101,10 +103,14 @@ struct bin_bcast_cuda { int nr[4] = { nr0, nr1, nr2, nr3 }; // collapse dimensions until first broadcast dimension - int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb0[] = {nb0, nb1, nb2, nb3}; + + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { cne[0] *= cne[1]; cne[1] = cne[2]; @@ -118,32 +124,47 @@ struct bin_bcast_cuda { cnb[3] *= cne[3]; }; - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne0); - collapse(cne1); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } } } + { - int64_t ne0 = cne0[0]; - int64_t ne1 = cne0[1]; - int64_t ne2 = cne0[2]; - int64_t ne3 = cne0[3]; + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); int64_t ne10 = cne1[0]; int64_t ne11 = cne1[1]; int64_t ne12 = cne1[2]; int64_t ne13 = cne1[3]; - size_t nb0 = cnb0[0]; - size_t nb1 = cnb0[1]; - size_t nb2 = cnb0[2]; - size_t nb3 = cnb0[3]; + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; size_t nb10 = cnb1[0]; size_t nb11 = cnb1[1]; @@ -160,7 +181,28 @@ struct bin_bcast_cuda { size_t s12 = nb12 / sizeof(src1_t); size_t s13 = nb13 / sizeof(src1_t); + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s00 == 1); GGML_ASSERT(s10 == 1); const int block_size = 128; @@ -179,13 +221,14 @@ struct bin_bcast_cuda { ); if (block_nums.z > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3, + /* s00, */ s01, s02, s03, /* s10, */ s11, s12, s13); } else { k_bin_bcast<<>>( @@ -193,6 +236,7 @@ struct bin_bcast_cuda { ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3, + /* s00, */ s01, s02, s03, /* s10, */ s11, s12, s13); } } diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index ed4fa2748..b15e35782 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h vals[ix] = x0[ix]; } + __syncthreads(); + #pragma unroll for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) { if (need_check && i0 + iy + 2*threadIdx.x >= k) { diff --git a/ggml-metal.m b/ggml-metal.m index f4a831b52..44a3d05f3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1754,15 +1754,10 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT_ID: { - //GGML_ASSERT(ne00 == ne10); - //GGML_ASSERT(ne03 == ne13); const int n_as = src0->ne[2]; - // max size of the src1ids array in the kernel shared buffer - GGML_ASSERT(ne11 <= 4096); - // src2 = ids - const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20); + const int64_t ne20 = src2->ne[0]; const int64_t ne21 = src2->ne[1]; const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); @@ -1783,15 +1778,13 @@ static enum ggml_status ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = n_as; + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; - const int idx = ((int32_t *) dst->op_params)[0]; - - // batch size - GGML_ASSERT(ne21 == ne11); // ? - GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting - const uint r2 = 1; - const uint r3 = 1; + // max size of the rowids array in the kernel shared buffer + GGML_ASSERT(dst_rows <= 2048); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1801,7 +1794,7 @@ static enum ggml_status ggml_metal_graph_compute( // !!! if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - ne11 > ne11_mm_min) { + dst_rows > dst_rows_min) { // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -1843,26 +1836,26 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:19]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -2015,72 +2008,72 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ne00 >= nth0*nth1); } - const int64_t _ne1 = 1; // kernels needs a reference in constant memory - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:21]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:22]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:23]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 36b87b2f0..1ed5632b4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -892,16 +892,16 @@ void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -1066,19 +1066,19 @@ void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1165,24 +1165,24 @@ void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -1435,24 +1435,24 @@ void kernel_mul_mv_f16_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F16_F32; @@ -3419,19 +3419,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -3599,19 +3599,19 @@ void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3865,19 +3865,19 @@ void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4104,19 +4104,19 @@ void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -4311,19 +4311,19 @@ void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -4448,19 +4448,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4577,19 +4577,19 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4716,19 +4716,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4848,19 +4848,19 @@ void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4980,19 +4980,19 @@ void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -5113,19 +5113,19 @@ void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -5203,19 +5203,19 @@ void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -5312,19 +5312,19 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK4_NL; @@ -5407,19 +5407,20 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -6361,25 +6362,25 @@ void kernel_mul_mm_impl(device const uchar * src0, } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids template void kernel_mul_mm_id_impl( device const uchar * src0, device const uchar * src1, - threadgroup short * src1ids, + threadgroup ushort2 * rowids, device float * dst, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant int64_t & ne0, int64_t ne1, - constant uint & r2, - constant uint & r3, + int64_t ne0ne1, threadgroup uchar * shared_memory, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6390,7 +6391,6 @@ void kernel_mul_mm_id_impl( const uint r0 = tgpig.y; const uint r1 = tgpig.x; - const uint im = tgpig.z; if (r1 * BLOCK_SIZE_N >= ne1) return; @@ -6408,19 +6408,16 @@ void kernel_mul_mm_id_impl( for (int i = 0; i < 8; i++){ c_res[i] = make_filled_simdgroup_matrix(0.f); } - short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb12 * id[1] + + nb11 * (id[0] % ne11) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { @@ -6449,11 +6446,11 @@ void kernel_mul_mm_id_impl( for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; @@ -6475,11 +6472,13 @@ void kernel_mul_mm_id_impl( threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); } } } @@ -6534,11 +6533,14 @@ kernel void kernel_mul_mm_id( device const uchar * src1, device float * dst, device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, constant uint64_t & nb10, @@ -6547,47 +6549,52 @@ kernel void kernel_mul_mm_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - // expert id - const int32_t id = tgpig.z/(ne12*ne13); - device const uchar * src0 = src0s + id*nb02; + const int32_t i02 = tgpig.z; + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + device const uchar * src0 = src0s + i02*nb02; - // row indices of src1 for expert id - threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192); + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + // TODO: parallelize this loop int64_t _ne1 = 0; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } } } + threadgroup_barrier(mem_flags::mem_threadgroup); + kernel_mul_mm_id_impl( src0, src1, - src1ids, + rowids, dst, ne00, ne02, nb01, nb02, + ne11, ne12, nb10, nb11, nb12, ne0, _ne1, - r2, - r3, + ne0*ne1, shared_memory, tgpig, tiitg, @@ -6648,24 +6655,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r // matrix-matrix multiplication // -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm) mat_mm_t; template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -6697,29 +6687,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m // indirect matrix-matrix multiplication // -typedef void (mat_mm_id_t)( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -6755,71 +6723,71 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); typedef void (kernel_mul_mv2_impl_t)( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); template void mmv_fn( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } @@ -6828,59 +6796,33 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -typedef void (mul_mv_impl_fn_t)( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); +typedef decltype(mmv_fn) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( @@ -6888,6 +6830,8 @@ kernel void kernel_mul_mv_id( device const char * src1, device float * dst, device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6905,43 +6849,50 @@ kernel void kernel_mul_mv_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; - tgpig.z = tgpig.z%(ne12*ne13); + tgpig.z = 0; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; impl_fn( - src0, - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - ne13, - nb10, - nb11, - nb12, - ne0, - ne1, - nb1, - r2, - r3, + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, shared_values, tgpig, tiitg, @@ -6949,36 +6900,7 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef void (kernel_mul_mv_id_t)( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); +typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index f5bb7da86..a9b310243 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; + return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID; GGML_UNUSED(backend); } diff --git a/ggml.c b/ggml.c index f2bbfa6f2..4c749ecda 100644 --- a/ggml.c +++ b/ggml.c @@ -4642,21 +4642,32 @@ void ggml_mul_mat_set_prec( // ggml_mul_mat_id -// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed -// this will allow computing all the used experts in a single matrix multiplication +/* + c = ggml_mul_mat_id(ctx, as, b, ids); + + as -> [cols, rows, n_expert] + ids -> [n_experts_used, n_tokens] (i32) + b -> [cols, n_expert_used, n_tokens] + c -> [cols, n_expert_used, n_tokens] + + in b, n_experts_used can be broadcasted to match the n_expert_used of ids + + c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids +*/ struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b) { - + struct ggml_tensor * b, + struct ggml_tensor * ids) { + GGML_ASSERT(!ggml_is_transposed(as)); GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d - GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row - GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); - GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast bool is_node = false; @@ -4664,11 +4675,9 @@ struct ggml_tensor * ggml_mul_mat_id( is_node = true; } - const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_set_op_params_i32(result, 0, id); - result->op = GGML_OP_MUL_MAT_ID; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = as; @@ -11127,11 +11136,6 @@ static void ggml_compute_forward_mul_mat_id( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); @@ -11142,22 +11146,21 @@ static void ggml_compute_forward_mul_mat_id( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // broadcast is not supported with mmid - assert(ne12 == 1); - assert(ne13 == 1); - // row groups - const int id = ggml_get_op_params_i32(dst, 0); - const int n_as = src0->ne[2]; + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; - #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { @@ -11183,13 +11186,18 @@ static void ggml_compute_forward_mul_mat_id( // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); - // group rows by src0 matrix - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - GGML_ASSERT(row_id >= 0 && row_id < n_as); - MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; - matrix_row_counts[row_id] += 1; + // group rows by src0 matrix + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } } return; @@ -11207,15 +11215,13 @@ static void ggml_compute_forward_mul_mat_id( continue; } - size_t src0_offset = cur_a*src0->nb[2]; + const char * src0_cur = (const char *) src0->data + cur_a*nb02; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1*ne12*ne13; // src1 rows - - //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows // distribute the thread work across the inner or outer loop based on which one is larger @@ -11234,13 +11240,11 @@ static void ggml_compute_forward_mul_mat_id( const int64_t ir110 = dr1*ith1; const int64_t ir111 = MIN(ir110 + dr1, nr1); - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); - // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { - sched_yield(); - continue; - } + //if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // continue; + //} // block-tiling attempt const int64_t blck_0 = 16; @@ -11252,20 +11256,16 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix - const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; - const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); - const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + const int64_t _i12 = ir1; // logical row index for this expert - // broadcast src0 into src1 - //const int64_t i03 = i13/r3; - //const int64_t i02 = i12/r2; + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 - const char * src0_row = (const char *) src0->data + src0_offset; + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using @@ -11273,25 +11273,26 @@ static void ggml_compute_forward_mul_mat_id( // TODO: this is a bit of a hack, we should probably have a better way to handle this const char * src1_col = (const char *) wdata + (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1); + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); } } } } - #undef MMID_MATRIX_ROW +#undef MMID_MATRIX_ROW } // ggml_compute_forward_out_prod @@ -18830,7 +18831,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int n_as = src0->ne[2]; cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; case GGML_OP_OUT_PROD: { @@ -21262,12 +21263,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && cur != NULL; - ggml_set_name(cur, ctx->infos[i].name.data); - if (!ok) { break; } + ggml_set_name(cur, ctx->infos[i].name.data); + // point the data member to the appropriate location in the binary blob using the tensor infos if (!params.no_alloc) { //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file diff --git a/ggml.h b/ggml.h index db14c97d5..f065988c9 100644 --- a/ggml.h +++ b/ggml.h @@ -1162,13 +1162,11 @@ extern "C" { enum ggml_prec prec); // indirect matrix multiplication - // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b); + struct ggml_tensor * b, + struct ggml_tensor * ids); // A: m columns, n rows, // B: p columns, n rows, diff --git a/gguf-py/README.md b/gguf-py/README.md index 22d7ffa52..a04c22759 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -21,6 +21,8 @@ pip install gguf [scripts/gguf-convert-endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-convert-endian.py) — Allows converting the endianness of GGUF files. +[scripts/gguf-new-metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-new-metadata.py) — Copies a GGUF file with added/modified/removed metadata values. + ## Development Maintainers who participate in development of this package are advised to install it in editable mode: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4b0b6c4c6..ba24065a8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -90,6 +90,8 @@ class Keys: HF_JSON = "tokenizer.huggingface.json" RWKV = "tokenizer.rwkv.world" CHAT_TEMPLATE = "tokenizer.chat_template" + CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" + CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" @@ -133,6 +135,7 @@ class MODEL_ARCH(IntEnum): XVERSE = auto() COMMAND_R = auto() DBRX = auto() + OLMO = auto() class MODEL_TENSOR(IntEnum): @@ -208,6 +211,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -693,6 +697,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.OLMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index ff9326d59..e3dbca454 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -6,7 +6,8 @@ import struct import tempfile from enum import Enum, auto from io import BufferedWriter -from typing import IO, Any, Sequence +from typing import IO, Any, Sequence, Mapping +from string import ascii_letters, digits import numpy as np @@ -466,7 +467,33 @@ class GGUFWriter: def add_add_space_prefix(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) - def add_chat_template(self, value: str) -> None: + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: + if isinstance(value, list): + template_default = None + template_names = set() + + for choice in value: + name = choice.get('name', '') + template = choice.get('template') + + # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it + name = ''.join((c if c in ascii_letters + digits else '_' for c in name)) + + if name and template is not None: + if name == 'default': + template_default = template + else: + template_names.add(name) + self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template) + + if template_names: + self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names)) + + if template_default is None: + return + + value = template_default + self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) def add_prefix_token_id(self, id: int) -> None: diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index a23136b18..378eaecad 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -141,7 +141,7 @@ class SpecialVocab: with open(tokenizer_config_file, encoding = 'utf-8') as f: tokenizer_config = json.load(f) chat_template = tokenizer_config.get('chat_template') - if chat_template is None or isinstance(chat_template, str): + if chat_template is None or isinstance(chat_template, (str, list)): self.chat_template = chat_template else: print( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 13cbfffbc..d1d876d6d 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -33,3 +33,4 @@ build-backend = "poetry.core.masonry.api" gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint" gguf-dump = "scripts:gguf_dump_entrypoint" gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint" +gguf-new-metadata = "scripts:gguf_new_metadata_entrypoint" diff --git a/gguf-py/scripts/__init__.py b/gguf-py/scripts/__init__.py index 77132db7a..1ad45639a 100644 --- a/gguf-py/scripts/__init__.py +++ b/gguf-py/scripts/__init__.py @@ -8,5 +8,6 @@ os.environ["NO_LOCAL_GGUF"] = "TRUE" gguf_convert_endian_entrypoint = import_module("scripts.gguf-convert-endian").main gguf_dump_entrypoint = import_module("scripts.gguf-dump").main gguf_set_metadata_entrypoint = import_module("scripts.gguf-set-metadata").main +gguf_new_metadata_entrypoint = import_module("scripts.gguf-new-metadata").main del import_module, os diff --git a/gguf-py/scripts/gguf-new-metadata.py b/gguf-py/scripts/gguf-new-metadata.py new file mode 100644 index 000000000..3444ab418 --- /dev/null +++ b/gguf-py/scripts/gguf-new-metadata.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +import logging +import argparse +import os +import sys +import json +from pathlib import Path + +import numpy as np +from typing import Any, Mapping, Sequence + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +import gguf + +logger = logging.getLogger("gguf-new-metadata") + + +def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian: + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # Host is little endian + host_endian = gguf.GGUFEndian.LITTLE + swapped_endian = gguf.GGUFEndian.BIG + else: + # Sorry PDP or other weird systems that don't use BE or LE. + host_endian = gguf.GGUFEndian.BIG + swapped_endian = gguf.GGUFEndian.LITTLE + + if reader.byte_order == "S": + return swapped_endian + else: + return host_endian + + +def decode_field(field: gguf.ReaderField) -> Any: + if field and field.types: + main_type = field.types[0] + + if main_type == gguf.GGUFValueType.ARRAY: + sub_type = field.types[-1] + + if sub_type == gguf.GGUFValueType.STRING: + return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data] + else: + return [pv for idx in field.data for pv in field.parts[idx].tolist()] + if main_type == gguf.GGUFValueType.STRING: + return str(bytes(field.parts[-1]), encoding='utf8') + else: + return field.parts[-1][0] + + return None + + +def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + + return decode_field(field) + + +def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None: + for field in reader.fields.values(): + # Suppress virtual fields and fields written by GGUFWriter + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): + logger.debug(f'Suppressing {field.name}') + continue + + # Skip old chat templates if we have new ones + if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: + logger.debug(f'Skipping {field.name}') + continue + + if field.name in remove_metadata: + logger.debug(f'Removing {field.name}') + continue + + old_val = decode_field(field) + val = new_metadata.get(field.name, old_val) + + if field.name in new_metadata: + logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"') + del new_metadata[field.name] + elif val is not None: + logger.debug(f'Copying {field.name}') + + if val is not None: + writer.add_key(field.name) + writer.add_val(val, field.types[0]) + + if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: + logger.debug('Adding chat template(s)') + writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]) + del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] + + # TODO: Support other types than string? + for key, val in new_metadata.items(): + logger.debug(f'Adding {key}: {val}') + writer.add_key(key) + writer.add_val(val, gguf.GGUFValueType.STRING) + + for tensor in reader.tensors: + # Dimensions are written in reverse order, so flip them first + shape = np.flipud(tensor.shape) + writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type) + + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_ti_data_to_file() + + for tensor in reader.tensors: + writer.write_tensor_data(tensor.data) + + writer.close() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata") + parser.add_argument("input", type=Path, help="GGUF format model input filename") + parser.add_argument("output", type=Path, help="GGUF format model output filename") + parser.add_argument("--general-name", type=str, help="The models general.name") + parser.add_argument("--general-description", type=str, help="The models general.description") + parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)") + parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)") + parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model") + parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation") + parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") + args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"]) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + new_metadata = {} + remove_metadata = args.remove_metadata or [] + + if args.general_name: + new_metadata[gguf.Keys.General.NAME] = args.general_name + + if args.general_description: + new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description + + if args.chat_template: + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template + + if args.chat_template_config: + with open(args.chat_template_config, 'r') as fp: + config = json.load(fp) + template = config.get('chat_template') + if template: + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template + + if remove_metadata: + logger.warning('*** Warning *** Warning *** Warning **') + logger.warning('* Most metadata is required for a fully functional GGUF file,') + logger.warning('* removing crucial metadata may result in a corrupt output file!') + + if not args.force: + logger.warning('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + logger.info("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + + logger.info(f'* Loading: {args.input}') + reader = gguf.GGUFReader(args.input, 'r') + + arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE) + endianess = get_byteorder(reader) + + if os.path.isfile(args.output) and not args.force: + logger.warning('*** Warning *** Warning *** Warning **') + logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!') + logger.warning('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + logger.info("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + + logger.info(f'* Writing: {args.output}') + writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess) + + alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT) + if alignment is not None: + logger.debug(f'Setting custom alignment: {alignment}') + writer.data_alignment = alignment + + copy_with_new_metadata(reader, writer, new_metadata, remove_metadata) + + +if __name__ == '__main__': + main() diff --git a/llama.cpp b/llama.cpp index d828dd786..add75818b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -221,6 +221,7 @@ enum llm_arch { LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, + LLM_ARCH_OLMO, LLM_ARCH_UNKNOWN, }; @@ -255,6 +256,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -989,6 +991,20 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -4070,6 +4086,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4495,6 +4523,13 @@ static bool llm_load_tensors( auto & hparams = model.hparams; +#ifdef GGML_USE_SYCL + // disable MoE with SYCL until mul_mat_id is updated + if (hparams.n_expert > 0) { + n_gpu_layers = 0; + } +#endif + model.split_mode = split_mode; model.main_gpu = main_gpu; model.n_gpu_layers = n_gpu_layers; @@ -5659,6 +5694,37 @@ static bool llm_load_tensors( layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -6109,6 +6175,100 @@ static struct ggml_tensor * llm_build_ffn( return cur; } +static struct ggml_tensor * llm_build_moe_ffn( + struct ggml_context * ctx, + struct ggml_tensor * cur, + struct ggml_tensor * gate_inp, + struct ggml_tensor * up_exps, + struct ggml_tensor * gate_exps, + struct ggml_tensor * down_exps, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + const llm_build_cb & cb, + int il) { + int64_t n_embd = cur->ne[0]; + int64_t n_tokens = cur->ne[1]; + + ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + if (norm_w) { + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); + } + + cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + switch (type_op) { + case LLM_FFN_SILU: + { + gate = ggml_silu(ctx, gate); + cb(gate, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + gate = ggml_gelu(ctx, gate); + cb(gate, "ffn_moe_gelu", il); + } break; + default: + GGML_ASSERT(false); + } + + ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); + + ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + experts = ggml_mul(ctx, experts, weights); + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx, moe_out); + } + + return moe_out; +} + // if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, @@ -6698,7 +6858,15 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il); + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -6730,80 +6898,6 @@ struct llm_build_context { return gf; } - // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505 - ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) { - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - - if (norm_w) { - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - } - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(gate, "ffn_moe_gate", il); - - switch (type_op) { - case LLM_FFN_SILU: - { - gate = ggml_silu(ctx0, gate); - cb(gate, "ffn_moe_silu", il); - } break; - case LLM_FFN_GELU: - { - gate = ggml_gelu(ctx0, gate); - cb(gate, "ffn_moe_gelu", il); - } break; - default: - GGML_ASSERT(false); - } - - cur_expert = ggml_mul(ctx0, cur_up, gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - return moe_out; - } - struct ggml_cgraph * build_baichuan() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -7251,7 +7345,15 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il); + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_GELU, true, + cb, il); + cb(cur, "ffn_moe_out", il); // Grok // if layer_out_norm is present then apply it before adding the input @@ -7263,7 +7365,6 @@ struct llm_build_context { cb(cur, "layer_out_norm", il); } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -7387,7 +7488,15 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_out_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il); + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(cur, "ffn_moe_out", il); cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -8559,12 +8668,6 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, Qcur); - ggml_build_forward_expand(gf, Kcur); - ggml_build_forward_expand(gf, Vcur); - Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, @@ -8715,7 +8818,16 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il); + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + cb, il); + cb(cur, "ffn_moe_out", il); // FFN shared expert { @@ -10100,6 +10212,139 @@ struct llm_build_context { return gf; } + + // ref: https://allenai.org/olmo + // based on the original build_llama() function, changes: + // * non-parametric layer norm + // * clamp qkv + // * removed bias + // * removed MoE + struct ggml_cgraph * build_olmo() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + NULL, NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10305,6 +10550,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_dbrx(); } break; + case LLM_ARCH_OLMO: + { + result = llm.build_olmo(); + } break; default: GGML_ASSERT(false); } @@ -15167,6 +15416,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: + case LLM_ARCH_OLMO: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index d1272506c..fd0ee88b2 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -12,19 +12,7 @@ bench_args="${@:3}" rm -f llama-bench.sqlite -backend="cpu" - -if [[ "$OSTYPE" == "darwin"* ]]; then - backend="metal" -elif command -v nvcc &> /dev/null; then - backend="cuda" -fi - -make_opts="" - -if [[ "$backend" == "cuda" ]]; then - make_opts="LLAMA_CUDA=1" -fi +# to test a backend, call the script with the corresponding environment variable (e.g. LLAMA_CUDA=1 ./scripts/compare-commits.sh ...) git checkout $1 make clean && make -j32 $make_opts llama-bench diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0b245fd39..1cc5e53bd 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -101,7 +101,7 @@ static std::vector tensor_to_float(const ggml_tensor * t) { } else if (t->type == GGML_TYPE_I8) { tv.push_back((float)*(int8_t *) &buf[i]); } else if (quantized) { - tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type)); + tt.to_float(&buf[i], vq.data(), bs); tv.insert(tv.end(), vq.begin(), vq.end()); } else { GGML_ASSERT(false); @@ -958,14 +958,14 @@ struct test_mul_mat_id : public test_case { const ggml_type type_a; const ggml_type type_b; const int n_mats; - const int id; + const int n_used; + const bool b; // brodcast b matrix const int64_t m; const int64_t n; const int64_t k; - const bool v; // view (non-contiguous ids) std::string vars() override { - return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v); + return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k); } double max_nmse_err() override { @@ -982,20 +982,22 @@ struct test_mul_mat_id : public test_case { } test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, - int n_mats = 2, int id = 0, - int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false) - : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), - m(m), n(n), k(k), v(v) {} + int n_mats = 8, int n_used = 2, bool b = false, + int64_t m = 32, int64_t n = 32, int64_t k = 32) + : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b), + m(m), n(n), k(k) { + GGML_ASSERT(n_used <= n_mats); + } ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) - ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats); + ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats); ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); - if (v) { - ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0); + if (n_used != n_mats) { + ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0); } - ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n); - ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b); + ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n); + ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids); return out; } @@ -1692,7 +1694,6 @@ public: } }; - // Llama struct test_llama : public test_llm { static constexpr float freq_base = 10000.0f; @@ -1956,6 +1957,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, }; + const ggml_type base_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_K, + GGML_TYPE_IQ2_XXS + }; + + const ggml_type other_types[] = { + GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + // unary ops for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { test_cases.emplace_back(new test_unary((ggml_unary_op) op)); @@ -2064,7 +2084,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } - for (ggml_type type_a : all_types) { + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); @@ -2084,6 +2104,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + for (ggml_type type_a : other_types) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); + } + } + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1})); @@ -2091,13 +2117,32 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1})); - for (ggml_type type_a : all_types) { + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { - for (int n_mats : {2, 4, 8}) { - for (int id = 0; id < n_mats; id++) { - for (bool v : {false, true}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v)); - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v)); + for (int n_mats : {4, 8}) { + for (int n_used : {1, 2, 4}) { + for (bool b : {false, true}) { + for (int n : {1, 32}) { + int m = 512; + int k = 256; + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k)); + } + } + } + } + } + } + + for (ggml_type type_a : other_types) { + for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { + for (int n_mats : {4}) { + for (int n_used : {2}) { + for (bool b : {false}) { + for (int n : {1}) { + int m = 512; + int k = 256; + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k)); + } } } }