diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e24b8a319..a5813839f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -581,6 +581,27 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        enum ggml_object_type type;
+
+        char padding[4];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    enum tensor_parallel_mode {
+        TENSOR_NO_CHANGE,
+        TENSOR_SPLIT_BY_ROW,
+        TENSOR_SPLIT_BY_COLUMN,
+        TENSOR_KEEPED_ON_MASTER,
+    };
+
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
@@ -616,6 +637,8 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
+        enum tensor_parallel_mode split_mode = tensor_parallel_mode::TENSOR_NO_CHANGE;
+
        // char padding[4];
     };
 
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 6978a3192..a9217658b 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -1239,6 +1239,15 @@ static void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
+static void allreduce_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    auto dev = ccl::create_device(stream->get_device());
+    auto ctx = ccl::create_context(stream->get_context());
+    auto comm = dpct::dev_mgr::instance().create_ccl_communicator(dev, ctx);
+    auto ccl_stream = ccl::create_stream(*stream);
+    ccl::allreduce(x, dst, k, ccl::reduction::sum, comm, ccl_stream).wait();
+}
+
 static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
                                  queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
@@ -1736,6 +1745,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
             global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+int ggml_backend_sycl_rank() {
+    // use the ccl/MPI rank as the main gpu index
+    return dpct::dev_mgr::instance().get_ccl_rank();
+}
+
+int ggml_backend_sycl_world_size() {
+    // number of ranks in the ccl/MPI world
+    return dpct::dev_mgr::instance().get_ccl_world_size();
+}
+
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -2270,6 +2289,21 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
+inline void ggml_sycl_op_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                                   ggml_tensor *dst, const float *src0_dd,
+                                   const float *src1_dd, float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    allreduce_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                                      ggml_tensor *dst, const float *src0_dd, const float *src1_dd,
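The partial outputs produced by a column-split weight are combined with an element-wise sum across ranks. Below is a minimal sketch of the oneCCL call that allreduce_f32_sycl() wraps, assuming a communicator and SYCL queue already exist (the patch builds them inside dpct::dev_mgr, see the helper.hpp changes further down); the function name is illustrative and not part of the patch.

    #include <sycl/sycl.hpp>
    #include <oneapi/ccl.hpp>

    // sum `count` floats element-wise across all ranks; every rank ends up with the full result
    static void sum_across_ranks(const float * src, float * dst, size_t count,
                                 ccl::communicator & comm, sycl::queue & q) {
        auto stream = ccl::create_stream(q);  // bind the collective to the same queue the matmul used
        ccl::allreduce(src, dst, count, ccl::reduction::sum, comm, stream).wait();
    }
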
@@ -3179,6 +3213,13 @@ static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_allreduce);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
 static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
@@ -3530,6 +3571,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     } else {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
+    if (src0->split_mode == tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN) {
+        ggml_sycl_allreduce(ctx, dst, src1, dst);
+    }
 }
 
 
@@ -4193,6 +4237,41 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static bool split_tensor(const struct ggml_tensor * src, void * dst, const void * data, int split_mode) {
+    int rank = ggml_backend_sycl_rank();
+    int world_size = ggml_backend_sycl_world_size();
+    auto type_traits = ggml_internal_get_type_traits(src->type);
+    size_t element_size = type_traits.type_size / type_traits.blck_size;
+    const int64_t dst_size = ggml_nelements(src) * element_size / world_size;
+    switch (split_mode) {
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN: {
+            const int64_t nr = ggml_nrows(src);
+            const int64_t nc = src->ne[0];
+            const int64_t ndr = nr;
+            const int64_t ndc = nc / world_size;
+            for (int64_t i = 0; i < nr; ++i) {
+                memcpy(((char *) dst) + i * ndc * element_size,
+                       ((const char *) data) + i * nc * element_size + ndc * rank * element_size, ndc * element_size);
+            }
+        } break;
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_ROW: {
+            memcpy(((char *) dst), ((const char *) data) + dst_size * rank, dst_size);
+        } break;
+        case tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER: {
+            if (rank == 0) {
+                memcpy(((char *) dst), ((const char *) data), dst_size);
+            } else {
+                memset(((char *) dst), 0, dst_size);
+            }
+        } break;
+        default: {
+            return false;
+        } break;
+    }
+    return true;
+}
+
+
 static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
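For TENSOR_SPLIT_BY_COLUMN the copy above packs one contiguous slice of every row into the per-rank buffer. A host-only sketch of that addressing, written for a plain row-major float matrix (the helper below is illustrative and assumes nc is the full, unsplit row length):

    #include <cstdint>
    #include <cstring>

    // rank r keeps columns [r*nc/ws, (r+1)*nc/ws) of every row of an nr x nc row-major matrix
    static void split_by_column(const float * src, float * dst,
                                int64_t nr, int64_t nc, int rank, int world_size) {
        const int64_t ndc = nc / world_size;        // columns kept by this rank
        for (int64_t i = 0; i < nr; ++i) {
            std::memcpy(dst + i * ndc,              // packed destination row
                        src + i * nc + rank * ndc,  // this rank's slice of the source row
                        ndc * sizeof(float));
        }
    }
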
@@ -4205,7 +4284,14 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
     char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
+
+    if (tensor->split_mode == tensor_parallel_mode::TENSOR_NO_CHANGE) {
+        memcpy(host_buf, data, size);
+    } else {
+        if (!split_tensor(tensor, host_buf, data, tensor->split_mode)) {
+            std::cerr << "split tensor failed!" << std::endl;
+        }
+    }
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
                             .wait()));
@@ -4419,14 +4505,25 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static bool ggml_backend_sycl_buffer_type_initialized = false;
 
     if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            auto & device_i = dpct::dev_mgr::instance().get_device(i);
-            queue_ptr stream = &(device_i.default_queue());
-            ggml_backend_sycl_buffer_types[i] = {
+        if (dpct::dev_mgr::instance().get_ccl_world_size() > 1) {
+            auto rank = dpct::dev_mgr::instance().get_ccl_rank();
+            auto & device_tp = dpct::dev_mgr::instance().get_device(rank);
+            queue_ptr stream = &(device_tp.default_queue());
+            // TODO(xi): always use buffer type slot 0 to avoid changes in the public interface
+            ggml_backend_sycl_buffer_types[0] = {
                 /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
-                /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
-            };
-        }
+                /* .context = */ new ggml_backend_sycl_buffer_type_context{rank, GGML_SYCL_NAME + std::to_string(rank), stream},
+            };
+        } else {
+            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+                auto & device_i = dpct::dev_mgr::instance().get_device(i);
+                queue_ptr stream = &(device_i.default_queue());
+                ggml_backend_sycl_buffer_types[i] = {
+                    /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
+                    /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+                };
+            }
+        }
         ggml_backend_sycl_buffer_type_initialized = true;
     }
     return &ggml_backend_sycl_buffer_types[device];
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
index fe4a8f744..700f65cc6 100644
--- a/ggml/src/ggml-sycl/dpct/helper.hpp
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -15,6 +15,7 @@
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
+#include <oneapi/ccl.hpp>
 #include <oneapi/mkl.hpp>
 #include <map>
 
@@ -479,6 +480,8 @@ namespace dpct
         int _max_nd_range_size_i[3];
         uint32_t _device_id;
         std::array<unsigned char, 16> _uuid;
+        uint32_t _rank;
+        uint32_t _world_size;
     };
 
     static int get_major_version(const sycl::device &dev)
@@ -870,7 +873,12 @@ namespace dpct
             }
             return -1;
         }
-
+        inline int get_ccl_rank() { return _rank; }
+        inline int get_ccl_world_size() { return _world_size; }
+        inline ccl::communicator create_ccl_communicator(ccl::device dev, ccl::context ctx) {
+            return ccl::create_communicator(_world_size, _rank, dev, ctx, _kvs);
+
+        }
         inline std::string get_preferred_gpu_platform_name() {
             std::string result;
 
@@ -993,6 +1001,26 @@ namespace dpct
         static bool compare_backend(std::string &backend1, std::string &backend2) {
             return convert_backend_index(backend1) < convert_backend_index(backend2);
         }
+
+        void init_ccl() {
+            ccl::init();
+            MPI_Init(NULL, NULL);
+            MPI_Comm_size(MPI_COMM_WORLD, &_world_size);
+            MPI_Comm_rank(MPI_COMM_WORLD, &_rank);
+            atexit(mpi_finalize);
+            ccl::kvs::address_type main_addr;
+            if (_rank == 0) {
+                _kvs = ccl::create_main_kvs();
+                main_addr = _kvs->get_address();
+                MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+            }
+            else {
+                MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+                _kvs = ccl::create_kvs(main_addr);
+            }
+
+        }
+
         dev_mgr()
         {
             sycl::device default_device =
@@ -1050,6 +1078,7 @@ namespace dpct
                     _cpu_device = _devs.size() - 1;
                 }
             }
+            init_ccl();
         }
 
         void check_id(unsigned int id) const {
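The communicator bootstrap follows the standard oneCCL-over-MPI recipe: every process reads its rank and world size from MPI, rank 0 creates the main key-value store and broadcasts its address, and the other ranks attach to it. A standalone sketch of that handshake (illustrative only; the patch runs it inside dpct::dev_mgr and keeps the kvs around for later device-communicator creation):

    #include <mpi.h>
    #include <oneapi/ccl.hpp>

    int main(int argc, char ** argv) {
        ccl::init();
        MPI_Init(&argc, &argv);

        int rank = 0, world_size = 1;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &world_size);

        // rank 0 owns the main kvs; everyone else attaches via the broadcast address
        ccl::shared_ptr_class<ccl::kvs> kvs;
        ccl::kvs::address_type addr;
        if (rank == 0) {
            kvs  = ccl::create_main_kvs();
            addr = kvs->get_address();
        }
        MPI_Bcast(addr.data(), (int) addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
        if (rank != 0) {
            kvs = ccl::create_kvs(addr);
        }

        // a host communicator is enough to sanity-check the setup
        auto comm = ccl::create_communicator(world_size, rank, kvs);

        MPI_Finalize();
        return 0;
    }
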
@@ -1066,6 +1095,10 @@
         /// thread-id to device-id map.
         std::map<unsigned int, unsigned int> _thread2dev_map;
         int _cpu_device = -1;
+        // For tensor parallelism
+        int _rank = 0;
+        int _world_size = 1;
+        ccl::shared_ptr_class<ccl::kvs> _kvs;
     };
 
     static inline sycl::queue &get_default_queue()
diff --git a/include/llama.h b/include/llama.h
index 132937a07..6ce4b4c8b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -204,6 +204,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_NONE   = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER  = 1, // split layers and KV across GPUs
         LLAMA_SPLIT_MODE_ROW    = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_TENSOR = 3, // split tensors across GPUs
     };
 
     // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
diff --git a/src/llama.cpp b/src/llama.cpp
index 0accb1492..21cd489e8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2236,6 +2236,20 @@ static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     return piece;
 }
 
+static int ggml_backend_get_rank() {
+#if defined(GGML_USE_SYCL)
+    return ggml_backend_sycl_rank();
+#endif
+    return 0;
+}
+
+static int ggml_backend_get_world_size() {
+#if defined(GGML_USE_SYCL)
+    return ggml_backend_sycl_world_size();
+#endif
+    return 1;
+}
+
 static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
     ggml_backend_buffer_type_t buft = nullptr;
 
@@ -4352,6 +4366,10 @@ struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
     int n_created = 0;
+    // For tensor parallelism
+    int  world_size = 1;
+    int  rank       = 0;
+    bool enable_tp  = false;
 
     int64_t n_elements = 0;
     size_t  n_bytes    = 0;
@@ -4611,6 +4629,8 @@
 
         this->use_mmap = use_mmap;
         this->check_tensors = check_tensors;
+        world_size = ggml_backend_get_world_size();
+        rank       = ggml_backend_get_rank();
     }
 
     ~llama_model_loader() {
@@ -4834,11 +4854,20 @@
         return get_tensor_meta(get_tensor_name(i));
     }
 
-    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
+    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, int flags) {
         struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
         ggml_set_name(tensor, ggml_get_name(cur));
+        // the split flags may be combined with TENSOR_NOT_REQUIRED, so mask that bit out before comparing
+        const int split_flags = flags & ~TENSOR_NOT_REQUIRED;
+        if (split_flags == TENSOR_SPLIT_BY_ROW) {
+            tensor->split_mode = tensor_parallel_mode::TENSOR_SPLIT_BY_ROW;
+        }
+        if (split_flags == TENSOR_SPLIT_BY_COLUMN) {
+            tensor->split_mode = tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN;
+        }
+        if (split_flags == TENSOR_KEEPED_ON_MASTER) {
+            tensor->split_mode = tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER;
+        }
 
-        if (duplicated) {
+        if (flags & TENSOR_DUPLICATED) {
             size_data += ggml_nbytes(cur);
         } else {
             n_created++;
@@ -4879,6 +4908,9 @@
     static const int TENSOR_NOT_REQUIRED = 1;
     static const int TENSOR_DUPLICATED   = 2;
+    static const int TENSOR_SPLIT_BY_ROW     = 4;
+    static const int TENSOR_SPLIT_BY_COLUMN  = 8;
+    static const int TENSOR_KEEPED_ON_MASTER = 12;
 
     struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
         const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
@@ -4887,7 +4919,7 @@
             return NULL;
         }
 
-        return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
+        return create_tensor_for(ctx, cur, flags);
     }
 
     struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
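With the new enum value, tensor parallelism is requested through the same llama_model_params field that selects layer or row splitting. A hedged usage sketch (the model path is a placeholder; the process is expected to be launched once per GPU under mpirun so that each rank later picks its own main_gpu):

    #include "llama.h"

    int main(void) {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        mparams.split_mode   = LLAMA_SPLIT_MODE_TENSOR; // new mode added by this patch
        mparams.n_gpu_layers = 999;                     // offload everything, as usual for GPU runs

        // "model.gguf" is a placeholder path
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }
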
@@ -4963,7 +4995,7 @@
                 continue;
             }
             *first = std::min(*first, weight->offs);
-            *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
+            *last  = std::max(*last,  weight->offs + world_size * ggml_nbytes(tensor));
         } catch(...) {
             // the tensor is not in the model
         }
@@ -5060,7 +5092,6 @@
                 return false;
             }
         }
-
         size_t n_size = ggml_nbytes(cur);
 
         if (use_mmap) {
@@ -5126,9 +5157,9 @@
                 else
 #endif
                 {
-                    read_buf.resize(n_size);
+                    read_buf.resize(n_size * world_size);
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(read_buf.data(), n_size);
+                    file->read_raw(read_buf.data(), n_size * world_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
                     if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
                         throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
@@ -6911,7 +6942,7 @@
             }
         } else {
             ggml_backend_buffer_type_t split_buft;
-            if (split_mode == LLAMA_SPLIT_MODE_ROW) {
+            if (split_mode == LLAMA_SPLIT_MODE_ROW || split_mode == LLAMA_SPLIT_MODE_TENSOR) {
                 split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
             } else {
                 // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
@@ -6972,8 +7003,8 @@
     // create tensors for the weights
     {
         // note: cast to int64_t since we will use these for the tensor dimensions
-        const int64_t n_head        = hparams.n_head();
-        const int64_t n_head_kv     = hparams.n_head_kv();
+        int64_t n_head        = hparams.n_head();
+        int64_t n_head_kv     = hparams.n_head_kv();
         const int64_t n_embd        = hparams.n_embd;
         const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
         const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
@@ -6987,11 +7018,21 @@
         const int64_t n_expert      = hparams.n_expert;
         const int64_t n_expert_used = hparams.n_expert_used;
         const int64_t n_ctx_train   = hparams.n_ctx_train;
+        int64_t head_size = n_embd / n_head;
 
         if (n_expert > 0 && hparams.n_expert_used == 0) {
             throw std::runtime_error("model has expert layers but no expert layers are used");
         }
 
+        if (split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+            if (ml.world_size > 1) {
+                ml.enable_tp = true;
+                // the per-rank head counts have to be reduced before the tensors are created
+                n_head    /= ml.world_size;
+                n_head_kv /= ml.world_size;
+            }
+        }
+
         ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
         ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
         ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
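Dividing n_head and n_head_kv by the world size is what shrinks the per-rank projection matrices. A small worked sketch of the resulting shapes (the numbers are illustrative, loosely LLaMA-7B-like, and not taken from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_embd     = 4096;                 // illustrative model width
        const int64_t n_head     = 32;                   // illustrative head count
        const int64_t world_size = 2;                    // two ranks / GPUs
        const int64_t head_size  = n_embd / n_head;      // 128, computed before the split
        const int64_t n_head_tp  = n_head / world_size;  // 16 heads per rank

        // wq per rank: full n_embd input, 1/world_size of the output rows
        std::printf("wq per rank: %lld x %lld\n", (long long) n_embd, (long long) (head_size * n_head_tp)); // 4096 x 2048
        // wo per rank: 1/world_size of the input, full n_embd output (partial results are all-reduced)
        std::printf("wo per rank: %lld x %lld\n", (long long) (head_size * n_head_tp), (long long) n_embd); // 2048 x 4096
        return 0;
    }
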
@@ -7029,31 +7070,60 @@
                    auto & layer = model.layers[i];
 
                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
+                    // when tensor parallelism is enabled, create the attention tensors with their split mode set
+                    if (ml.enable_tp) {
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, llama_model_loader::TENSOR_SPLIT_BY_COLUMN);
-                    layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                    layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                    layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        // optional bias tensors
+                        auto bias_split_mode = llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_SPLIT_BY_COLUMN;
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     bias_split_mode);
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, bias_split_mode);
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, bias_split_mode);
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_KEEPED_ON_MASTER);
+                    } else {
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
-                    // optional bias tensors
-                    layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
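+                    // note: with tensor parallelism each rank keeps the full n_embd input dimension but only
+                    // 1/world_size of the wq/wk/wv output rows (n_head was divided above), while wo is split
+                    // along its input columns; the partial wo results are summed by the allreduce that
+                    // ggml_sycl_mul_mat now issues for TENSOR_SPLIT_BY_COLUMN tensors.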
 
                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    // TODO(chenxi): check whether n_rot should be split here instead of using head_size
+                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {head_size / 2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
 
                    if (n_expert == 0) {
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        // TODO(chenxi): only the n_expert == 0 case is supported for now
+                        if (ml.enable_tp) {
+                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, llama_model_loader::TENSOR_SPLIT_BY_COLUMN);
+                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
 
-                        // optional MLP bias
-                        layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                            // optional MLP bias
+                            auto bias_split_mode = llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_SPLIT_BY_COLUMN;
+                            layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   bias_split_mode);
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_KEEPED_ON_MASTER);
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   bias_split_mode);
+
+                        } else {
+                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                            // optional MLP bias
+                            layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
                    } else {
                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
@@ -8880,6 +8950,10 @@ static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
     model.hparams.vocab_only = params.vocab_only;
 
+    if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+        auto main_gpu = ggml_backend_get_rank();
+        params.main_gpu = main_gpu;
+    }
+
     try {
         llm_load_arch(ml, model);
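Nothing in this diff adds a command-line switch for the new mode, so the flag value below is hypothetical; only the LLAMA_SPLIT_MODE_TENSOR enum value and the MPI/oneCCL plumbing are part of the patch. Assuming the examples are taught to map such a flag onto split_mode, a two-GPU run would be launched with one MPI process per GPU, for example:

    # hypothetical "-sm tensor" value; the model path is a placeholder
    mpirun -np 2 ./llama-cli -m model.gguf -sm tensor -p "Hello"

Each rank then loads its shard of the weights (main_gpu is forced to the MPI rank above), and the column-split matmuls are completed by the oneCCL allreduce added in ggml-sycl.cpp.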