add tensor parallelism support to SYCL
Signed-off-by: Chen Xi <xi2chen@intel.com>
commit cb8507b3b4
parent 7691654c68

5 changed files with 264 additions and 36 deletions
@@ -581,6 +581,27 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };

+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        enum ggml_object_type type;
+
+        char padding[4];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    enum tensor_parallel_mode {
+        TENSOR_NO_CHANGE,
+        TENSOR_SPLIT_BY_ROW,
+        TENSOR_SPLIT_BY_COLUMN,
+        TENSOR_KEEPED_ON_MASTER,
+    };
+
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
@@ -616,6 +637,8 @@ extern "C" {

         void * extra; // extra things e.g. for ggml-cuda.cu

+        enum tensor_parallel_mode split_mode = tensor_parallel_mode::TENSOR_NO_CHANGE;
+
         // char padding[4];
     };

@@ -1239,6 +1239,15 @@ static void relu_f32_sycl(const float *x, float *dst, const int k,
         });
 }

+static void allreduce_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    auto dev = ccl::create_device(stream->get_device());
+    auto ctx = ccl::create_context(stream->get_context());
+    auto comm = dpct::dev_mgr::instance().create_ccl_communicator(dev, ctx);
+    auto ccl_stream = ccl::create_stream(*stream);
+    ccl::allreduce(x, dst, k, ccl::reduction::sum, comm, ccl_stream).wait();
+}
+
 static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
                                  queue_ptr stream) {
     const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
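
The helper above wraps the oneCCL SYCL-aware allreduce: it derives a ccl::device and ccl::context from the backend stream, asks dev_mgr for a communicator, and runs a blocking sum-reduction over k floats. The following standalone sketch shows the same call pattern in isolation; it is illustrative only, sum_across_ranks and its parameters are not part of the commit, and the communicator is assumed to come from create_ccl_communicator as above.

// Sketch only: how the oneCCL allreduce used above composes with a SYCL queue.
#include <oneapi/ccl.hpp>
#include <sycl/sycl.hpp>

static void sum_across_ranks(sycl::queue & q, ccl::communicator & comm,
                             const float * src, float * dst, size_t count) {
    auto op_stream = ccl::create_stream(q);                // wrap the SYCL queue for oneCCL
    ccl::allreduce(src, dst, count,                        // element-wise sum over all ranks
                   ccl::reduction::sum, comm, op_stream).wait();
}
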
@@ -1736,6 +1745,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
             global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }

+int ggml_backend_sycl_rank() {
+    // use the ccl rank as the main gpu
+    return dpct::dev_mgr::instance().get_ccl_rank();
+}
+
+int ggml_backend_sycl_world_size() {
+    // number of ranks participating in tensor parallelism
+    return dpct::dev_mgr::instance().get_ccl_world_size();
+}
+
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -2270,6 +2289,21 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }

+inline void ggml_sycl_op_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                                   ggml_tensor *dst, const float *src0_dd,
+                                   const float *src1_dd, float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    allreduce_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                      const ggml_tensor *src1, ggml_tensor *dst,
                                      const float *src0_dd, const float *src1_dd,
@@ -3179,6 +3213,13 @@ static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

+static void ggml_sycl_allreduce(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_allreduce);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
 static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
@@ -3530,6 +3571,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     } else {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
+    if (src0->split_mode == tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN) {
+        ggml_sycl_allreduce(ctx, dst, src1, dst);
+    }
 }

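
The extra step after the matmul is the usual tensor-parallel reduction: a weight split along the reduction (column) dimension yields only partial dot products on each rank, and the partials must be summed across ranks to recover the full output. A tiny host-side illustration of that identity (plain C++, not from the commit):

#include <cstdio>

// y = W * x with W's columns (the reduction dim) split across two "ranks":
// each rank multiplies its slice, and the full result is the element-wise sum.
int main() {
    const float W[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    const float x[4]    = {1, 1, 1, 1};

    float y_rank0[2] = {0, 0}, y_rank1[2] = {0, 0};
    for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) y_rank0[r] += W[r][c] * x[c];   // columns 0..1
        for (int c = 2; c < 4; ++c) y_rank1[r] += W[r][c] * x[c];   // columns 2..3
    }
    // the "allreduce": summing the partial outputs reproduces the full matmul
    printf("%.0f %.0f\n", y_rank0[0] + y_rank1[0], y_rank0[1] + y_rank1[1]); // 10 26
}
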
@@ -4193,6 +4237,41 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

+static bool split_tensor(const struct ggml_tensor * src, void * dst, const void * data, int split_mode) {
+    int rank = ggml_backend_sycl_rank();
+    int world_size = ggml_backend_sycl_world_size();
+    auto type_traits = ggml_internal_get_type_traits(src->type);
+    size_t element_size = type_traits.type_size / type_traits.blck_size;
+    const int64_t dst_size = ggml_nelements(src) * element_size / world_size;
+    switch (split_mode) {
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN: {
+            const int64_t nr = ggml_nrows(src);
+            const int64_t nc = src->ne[0];
+            const int64_t ndr = nr;
+            const int64_t ndc = nc / world_size;
+            for (int64_t i = 0; i < nr; ++i) {
+                memcpy(((char *) dst) + i * ndc * element_size,
+                       ((const char *) data) + i * nc * element_size + ndc * rank * element_size, ndc * element_size);
+            }
+        } break;
+        case tensor_parallel_mode::TENSOR_SPLIT_BY_ROW: {
+            memcpy(((char *) dst), ((const char *) data) + dst_size * rank, dst_size);
+        } break;
+        case tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER: {
+            if (rank == 0) {
+                memcpy(((char *) dst), ((const char *) data), dst_size);
+            } else {
+                memset(((char *) dst), 0, dst_size);
+            }
+        } break;
+        default: {
+            return false;
+        } break;
+    }
+    return true;
+}
+
+
 static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
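
For the column case above, each of the nr rows contributes one contiguous run of nc / world_size elements starting at this rank's column offset, which is exactly what the per-row memcpy implements. A small host-only demonstration of the same copy pattern (illustrative values, not from the commit):

#include <cstring>
#include <cstdio>

int main() {
    // 2 rows x 4 columns, stored row-major, world_size = 2, rank = 1
    const float src[2 * 4] = { 0,  1,  2,  3,
                              10, 11, 12, 13};
    const int nr = 2, nc = 4, world_size = 2, rank = 1;
    const int ndc = nc / world_size;                      // columns owned per rank

    float dst[2 * 2];
    for (int i = 0; i < nr; ++i) {
        std::memcpy(dst + i * ndc, src + i * nc + ndc * rank, ndc * sizeof(float));
    }
    // rank 1 owns columns 2..3 of every row: 2 3 12 13
    std::printf("%.0f %.0f %.0f %.0f\n", dst[0], dst[1], dst[2], dst[3]);
}
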
@@ -4205,7 +4284,14 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
     char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
+
+    if (tensor->split_mode == tensor_parallel_mode::TENSOR_NO_CHANGE) {
+        memcpy(host_buf, data, size);
+    } else {
+        if (!split_tensor(tensor, host_buf, data, tensor->split_mode)) {
+            std::cerr << "split tensor failed!" << std::endl;
+        }
+    }
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
             .wait()));
@@ -4419,14 +4505,25 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static bool ggml_backend_sycl_buffer_type_initialized = false;

     if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            auto & device_i = dpct::dev_mgr::instance().get_device(i);
-            queue_ptr stream = &(device_i.default_queue());
-            ggml_backend_sycl_buffer_types[i] = {
+        if (dpct::dev_mgr::instance().get_ccl_world_size() > 1) {
+            auto rank = dpct::dev_mgr::instance().get_ccl_rank();
+            auto & device_tp = dpct::dev_mgr::instance().get_device(rank);
+            queue_ptr stream = &(device_tp.default_queue());
+            // TODO(xi): buffer_types always use index 0 to avoid changes to the public code
+            ggml_backend_sycl_buffer_types[0] = {
                 /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
-                /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+                /* .context = */ new ggml_backend_sycl_buffer_type_context{rank, GGML_SYCL_NAME + std::to_string(rank), stream},
             };
-        }
+        } else {
+            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+                auto & device_i = dpct::dev_mgr::instance().get_device(i);
+                queue_ptr stream = &(device_i.default_queue());
+                ggml_backend_sycl_buffer_types[i] = {
+                    /* .iface   = */ ggml_backend_sycl_buffer_type_interface,
+                    /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+                };
+            }
+        }
         ggml_backend_sycl_buffer_type_initialized = true;
     }
     return &ggml_backend_sycl_buffer_types[device];

@@ -15,6 +15,7 @@

 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
+#include <oneapi/ccl.hpp>
 #include <oneapi/mkl.hpp>
 #include <map>

@@ -479,6 +480,8 @@ namespace dpct
         int _max_nd_range_size_i[3];
         uint32_t _device_id;
         std::array<unsigned char, 16> _uuid;
+        uint32_t _rank;
+        uint32_t _world_size;
     };

     static int get_major_version(const sycl::device &dev)
@@ -870,7 +873,12 @@ namespace dpct
            }
            return -1;
        }
+       inline int get_ccl_rank() { return _rank; }
+       inline int get_ccl_world_size() { return _world_size; }
+       inline ccl::communicator create_ccl_communicator(ccl::device dev, ccl::context ctx) {
+           return ccl::create_communicator(_world_size, _rank, dev, ctx, _kvs);
+       }

        inline std::string get_preferred_gpu_platform_name() {
            std::string result;
@@ -993,6 +1001,26 @@ namespace dpct
        static bool compare_backend(std::string &backend1, std::string &backend2) {
            return convert_backend_index(backend1) < convert_backend_index(backend2);
        }
+
+       void init_ccl() {
+           ccl::init();
+           MPI_Init(NULL, NULL);
+           MPI_Comm_size(MPI_COMM_WORLD, &_world_size);
+           MPI_Comm_rank(MPI_COMM_WORLD, &_rank);
+           atexit(mpi_finalize);
+
+           ccl::kvs::address_type main_addr;
+           if (_rank == 0) {
+               _kvs = ccl::create_main_kvs();
+               main_addr = _kvs->get_address();
+               MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+           }
+           else {
+               MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+               _kvs = ccl::create_kvs(main_addr);
+           }
+       }
+
        dev_mgr()
        {
            sycl::device default_device =
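
init_ccl follows the standard oneCCL bootstrap: rank 0 creates the main key-value store and broadcasts its address over MPI, the other ranks attach to it, and the resulting _kvs is later passed to ccl::create_communicator. Each rank is a separate process, so the binary is expected to be launched once per rank (for example under mpirun). Two pieces this hunk relies on but does not show are an MPI header somewhere in this file's include chain and the mpi_finalize handler registered with atexit; the following is a plausible sketch of such a handler, an assumption rather than code taken from the commit.

// Illustrative only: MPI_Finalize returns int, so atexit() needs a void() wrapper,
// and <mpi.h> has to be visible to this header.
#include <mpi.h>

static void mpi_finalize() {
    int is_finalized = 0;
    MPI_Finalized(&is_finalized);   // avoid finalizing twice
    if (!is_finalized) {
        MPI_Finalize();
    }
}
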
@@ -1050,6 +1078,7 @@ namespace dpct
                    _cpu_device = _devs.size() - 1;
                }
            }
+           init_ccl();
        }
        void check_id(unsigned int id) const
        {
@@ -1066,6 +1095,10 @@ namespace dpct
        /// thread-id to device-id map.
        std::map<unsigned int, unsigned int> _thread2dev_map;
        int _cpu_device = -1;
+       // For tensor parallelism
+       int _rank = 0;
+       int _world_size = 1;
+       ccl::shared_ptr_class<ccl::kvs> _kvs;
    };

    static inline sycl::queue &get_default_queue()

@@ -204,6 +204,7 @@ extern "C" {
        LLAMA_SPLIT_MODE_NONE   = 0, // single GPU
        LLAMA_SPLIT_MODE_LAYER  = 1, // split layers and KV across GPUs
        LLAMA_SPLIT_MODE_ROW    = 2, // split rows across GPUs
+       LLAMA_SPLIT_MODE_TENSOR = 3, // split tensors across GPUs
    };

    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
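
On the API side, the new mode is selected through the existing split_mode field of llama_model_params. A hedged sketch of what a caller (one process per MPI rank) might do, assuming the command-line plumbing for the new value lives elsewhere in the series; load_tp_model is an illustrative name, not part of the commit:

#include "llama.h"

// Sketch: request tensor-parallel loading via the public C API.
// main_gpu is expected to be overridden with the rank inside
// llama_model_load (see the llama.cpp hunk below).
llama_model * load_tp_model(const char * path) {
    llama_model_params mp = llama_model_default_params();
    mp.split_mode = LLAMA_SPLIT_MODE_TENSOR;   // new enum value from this patch
    return llama_load_model_from_file(path, mp);
}
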
src/llama.cpp (128 lines changed)

@@ -2236,6 +2236,20 @@ static std::string llama_token_to_piece(const struct llama_model * model, llama_
    return piece;
}

+static int ggml_backend_get_rank() {
+#if defined(GGML_USE_SYCL)
+    return ggml_backend_sycl_rank();
+#endif
+    return 0;
+}
+
+static int ggml_backend_get_world_size() {
+#if defined(GGML_USE_SYCL)
+    return ggml_backend_sycl_world_size();
+#endif
+    return 1;
+}
+
static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
    ggml_backend_buffer_type_t buft = nullptr;

@@ -4352,6 +4366,10 @@ struct llama_model_loader {
    int n_kv      = 0;
    int n_tensors = 0;
    int n_created = 0;
+   // For tensor parallelism
+   int  world_size = 1;
+   int  rank       = 0;
+   bool enable_tp  = false;

    int64_t n_elements = 0;
    size_t  n_bytes    = 0;
@@ -4611,6 +4629,8 @@ struct llama_model_loader {

        this->use_mmap = use_mmap;
        this->check_tensors = check_tensors;
+       world_size = ggml_backend_get_world_size();
+       rank       = ggml_backend_get_rank();
    }

    ~llama_model_loader() {
@@ -4834,11 +4854,20 @@ struct llama_model_loader {
        return get_tensor_meta(get_tensor_name(i));
    }

-   struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
+   struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, int flags) {
        struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
        ggml_set_name(tensor, ggml_get_name(cur));
+       if (flags & TENSOR_SPLIT_BY_ROW) {
+           tensor->split_mode = tensor_parallel_mode::TENSOR_SPLIT_BY_ROW;
+       }
+       if (flags & TENSOR_SPLIT_BY_COLUMN) {
+           tensor->split_mode = tensor_parallel_mode::TENSOR_SPLIT_BY_COLUMN;
+       }
+       if (flags & TENSOR_KEEPED_ON_MASTER) {
+           tensor->split_mode = tensor_parallel_mode::TENSOR_KEEPED_ON_MASTER;
+       }

-       if (duplicated) {
+       if (flags & TENSOR_DUPLICATED) {
            size_data += ggml_nbytes(cur);
        } else {
            n_created++;
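
The split flags combine with the existing loader flags by OR-ing, notably with TENSOR_NOT_REQUIRED for the optional bias tensors, which is why they are declared as separate bits and tested with a bitwise AND. A small illustration, not from the commit:

// For an optional bias that should be column-split, the loader passes:
int flags = llama_model_loader::TENSOR_NOT_REQUIRED        // 1
          | llama_model_loader::TENSOR_SPLIT_BY_COLUMN;    // 8  -> flags == 9

// bit tests keep working on the combined value
bool column_split = (flags & llama_model_loader::TENSOR_SPLIT_BY_COLUMN) != 0; // true
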
@@ -4879,6 +4908,9 @@ struct llama_model_loader {

    static const int TENSOR_NOT_REQUIRED = 1;
    static const int TENSOR_DUPLICATED   = 2;
+   static const int TENSOR_SPLIT_BY_ROW     = 4;
+   static const int TENSOR_SPLIT_BY_COLUMN  = 8;
+   static const int TENSOR_KEEPED_ON_MASTER = 16;

    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
        const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
@@ -4887,7 +4919,7 @@ struct llama_model_loader {
            return NULL;
        }

-       return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
+       return create_tensor_for(ctx, cur, flags);
    }

    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
@@ -4963,7 +4995,7 @@ struct llama_model_loader {
                continue;
            }
            *first = std::min(*first, weight->offs);
-           *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
+           *last  = std::max(*last,  weight->offs + world_size * ggml_nbytes(tensor));
        } catch(...) {
            // the tensor is not in the model
        }
@@ -5060,7 +5092,6 @@ struct llama_model_loader {
                return false;
            }
        }
-
        size_t n_size = ggml_nbytes(cur);

        if (use_mmap) {
@@ -5126,9 +5157,9 @@ struct llama_model_loader {
            else
#endif
            {
-               read_buf.resize(n_size);
+               read_buf.resize(n_size * world_size);
                file->seek(weight->offs, SEEK_SET);
-               file->read_raw(read_buf.data(), n_size);
+               file->read_raw(read_buf.data(), n_size * world_size);
                ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
                if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
                    throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
@@ -6911,7 +6942,7 @@ static bool llm_load_tensors(
        }
    } else {
        ggml_backend_buffer_type_t split_buft;
-       if (split_mode == LLAMA_SPLIT_MODE_ROW) {
+       if (split_mode == LLAMA_SPLIT_MODE_ROW || split_mode == LLAMA_SPLIT_MODE_TENSOR) {
            split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
        } else {
            // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
@@ -6972,8 +7003,8 @@ static bool llm_load_tensors(
    // create tensors for the weights
    {
        // note: cast to int64_t since we will use these for the tensor dimensions
-       const int64_t n_head        = hparams.n_head();
-       const int64_t n_head_kv     = hparams.n_head_kv();
+       int64_t n_head              = hparams.n_head();
+       int64_t n_head_kv           = hparams.n_head_kv();
        const int64_t n_embd        = hparams.n_embd;
        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
@@ -6987,11 +7018,21 @@ static bool llm_load_tensors(
        const int64_t n_expert      = hparams.n_expert;
        const int64_t n_expert_used = hparams.n_expert_used;
        const int64_t n_ctx_train   = hparams.n_ctx_train;
+       int64_t head_size           = n_embd / n_head;

        if (n_expert > 0 && hparams.n_expert_used == 0) {
            throw std::runtime_error("model has expert layers but no expert layers are used");
        }

+       if (split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+           if (world_size > 1) {
+               enable_tp = true;
+               // need to change the head counts before loading the tensors
+               n_head    /= world_size;
+               n_head_kv /= world_size;
+           }
+       }
+
        ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
        ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
        ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
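
A quick sanity check of the arithmetic with illustrative numbers (a 7B-like configuration, not taken from the commit): head_size is computed before the division so it keeps the true per-head width, while the head counts shrink per rank.

#include <cstdint>
#include <cstdio>

int main() {
    // illustrative 7B-like numbers, world_size = 2
    int64_t n_embd = 4096, n_head = 32, n_head_kv = 8, world_size = 2;

    int64_t head_size = n_embd / n_head;   // 128, taken before the split
    n_head    /= world_size;               // 16 attention heads per rank
    n_head_kv /= world_size;               // 4 KV heads per rank

    // wq is later created as {n_embd, head_size * n_head} per rank
    std::printf("per-rank wq: %lld x %lld\n",
                (long long) n_embd, (long long) (head_size * n_head)); // 4096 x 2048
}
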
@@ -7029,31 +7070,60 @@ static bool llm_load_tensors(
                    auto & layer = model.layers[i];

                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

-                   layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                   layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                   layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                   layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-
-                   // optional bias tensors
-                   layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                   layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                   layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                   layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                   // When tp is enabled, create the attention tensors with the split flags
+                   if (enable_tp) {
+                       layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                       layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa},           llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                       layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa},           llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                       layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, llama_model_loader::TENSOR_SPLIT_BY_COLUMN);
+
+                       // optional bias tensors
+                       auto bias_split_mode = llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_SPLIT_BY_COLUMN;
+                       layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     bias_split_mode);
+                       layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, bias_split_mode);
+                       layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, bias_split_mode);
+                       layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_KEEPED_ON_MASTER);
+                   } else {
+                       layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                       layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                       layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                       layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+
+                       // optional bias tensors
+                       layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                       layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                   }

                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});

-                   layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                   // TODO(chenxi): check whether n_rot needs to be split instead of using head_size
+                   layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {head_size / 2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));

                    if (n_expert == 0) {
-                       layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                       layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                       layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-
-                       // optional MLP bias
-                       layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
-                       layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                       layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                       // TODO(chenxi): tp currently only supports the non-MoE (n_expert == 0) case
+                       if (enable_tp) {
+                           layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_SPLIT_BY_ROW);
+                           layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, llama_model_loader::TENSOR_SPLIT_BY_COLUMN);
+                           layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_SPLIT_BY_ROW);

+                           // optional MLP bias
+                           auto bias_split_mode = llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_SPLIT_BY_COLUMN;
+                           layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   bias_split_mode);
+                           layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED | llama_model_loader::TENSOR_KEEPED_ON_MASTER);
+                           layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   bias_split_mode);
+                       } else {
+                           layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                           layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                           layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                           // optional MLP bias
+                           layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                           layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                           layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);
+                       }
                    } else {
                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
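
The flag assignment mirrors the usual Megatron-style pairing: the input-side projections (wq/wk/wv, ffn_gate, ffn_up) are split by row so every rank produces its own slice of the intermediate activations, while the output-side projections (wo, ffn_down) are split by column so a single allreduce per block (the TENSOR_SPLIT_BY_COLUMN check in ggml_sycl_mul_mat) restores the full activation. A toy check of that pairing, ignoring the gate projection for brevity (illustrative C++, not from the commit):

#include <cstdio>
#include <algorithm>

// ffn_up is split by row (each rank owns half of the n_ff outputs), ffn_down by
// column (each rank consumes the same half of the n_ff inputs), so a rank never
// needs the other rank's hidden activations and one final sum recovers the result.
int main() {
    const int n_embd = 2, n_ff = 4, world_size = 2;
    const float up[n_ff][n_embd]   = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};  // full n_ff x n_embd
    const float down[n_embd][n_ff] = {{1, 2, 3, 4}, {4, 3, 2, 1}};       // full n_embd x n_ff
    const float x[n_embd] = {1.0f, 2.0f};

    float y[n_embd] = {0, 0};
    for (int rank = 0; rank < world_size; ++rank) {
        const int ff_per_rank = n_ff / world_size;
        float partial[n_embd] = {0, 0};
        for (int j = rank * ff_per_rank; j < (rank + 1) * ff_per_rank; ++j) {
            float h = std::max(0.0f, up[j][0] * x[0] + up[j][1] * x[1]);  // this rank's slice of the hidden state
            for (int r = 0; r < n_embd; ++r) partial[r] += down[r][j] * h;
        }
        for (int r = 0; r < n_embd; ++r) y[r] += partial[r];  // the per-block allreduce
    }
    printf("%.0f %.0f\n", y[0], y[1]);  // 14 16, same as the unsplit matmul chain
}
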
@@ -8880,6 +8950,10 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
    llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);

    model.hparams.vocab_only = params.vocab_only;
+   if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+       auto main_gpu = ggml_backend_get_rank();
+       params.main_gpu = main_gpu;
+   }

    try {
        llm_load_arch(ml, model);