llama : make tensor_split ptr instead of array

Georgi Gerganov 2023-07-19 10:25:41 +03:00
parent 294f424554
commit 63ba9f3306
4 changed files with 8 additions and 4 deletions

examples/common.cpp

@@ -586,7 +586,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params)
     lparams.n_batch      = params.n_batch;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
-    memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
+    lparams.tensor_split = params.tensor_split;
     lparams.low_vram     = params.low_vram;
     lparams.seed         = params.seed;
     lparams.f16_kv       = params.memory_f16;
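The memcpy is gone because lparams.tensor_split now borrows the caller's array instead of copying it, so the array must outlive every use of the params struct. A minimal sketch of that implication, assuming a build where LLAMA_MAX_DEVICES >= 2 (g_tensor_split and make_lparams are hypothetical names, not part of this commit):

    #include "llama.h"

    // Hypothetical caller: the float array must stay alive for as long as
    // lparams.tensor_split is read, because the commit replaces a copy with
    // a pointer alias. Static storage makes the borrow trivially safe.
    static float g_tensor_split[LLAMA_MAX_DEVICES] = { 0.7f, 0.3f };

    struct llama_context_params make_lparams(void) {
        struct llama_context_params lparams = llama_context_default_params();
        lparams.tensor_split = g_tensor_split; // borrow, no memcpy
        return lparams; // safe only because g_tensor_split has static storage
    }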

ggml-cuda.cu

@@ -2512,6 +2512,9 @@ void ggml_init_cublas() {
 }
 
 void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    if (tensor_split == nullptr) {
+        return;
+    }
     bool all_zero = true;
     for (int i = 0; i < g_device_count; ++i) {
         if (tensor_split[i] != 0.0f) {
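The early return makes nullptr a valid value meaning "keep the default split". The surrounding code already treats an all-zero array that way; the guard extends the same meaning to a missing array. A simplified, self-contained sketch of that resolution logic (not the actual ggml-cuda.cu code, which weights the default split by per-device memory; here the fallback is an even split):

    #include <vector>

    // Simplified sketch: nullptr and an all-zero array both mean "default split".
    std::vector<float> resolve_split(const float * tensor_split, int device_count) {
        bool all_zero = true;
        if (tensor_split != nullptr) {
            for (int i = 0; i < device_count; ++i) {
                if (tensor_split[i] != 0.0f) { all_zero = false; break; }
            }
        }
        std::vector<float> split(device_count);
        for (int i = 0; i < device_count; ++i) {
            // even fallback; the real code derives proportions from VRAM
            split[i] = all_zero ? 1.0f / device_count : tensor_split[i];
        }
        return split;
    }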

llama.cpp

@@ -847,7 +847,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_batch           =*/ 512,
         /*.gpu_layers        =*/ 0,
         /*.main_gpu          =*/ 0,
-        /*.tensor_split      =*/ {0},
+        /*.tensor_split      =*/ nullptr,
         /*.rope_freq_base    =*/ 10000.0f,
         /*.rope_freq_scale   =*/ 1.0f,
         /*.progress_callback =*/ nullptr,
@@ -1287,7 +1287,7 @@ static bool llama_model_load(
         int n_batch,
         int n_gpu_layers,
         int main_gpu,
-        float * tensor_split,
+        const float * tensor_split,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
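With {0} replaced by nullptr as the default, "no split requested" can be detected without scanning the array, and the const qualifier on llama_model_load's parameter documents that the loader only reads the split. A hedged sketch of such a check (has_split is a hypothetical helper, not introduced by this commit):

    // Hypothetical helper: both "no array" (the new nullptr default) and an
    // all-zero array mean the caller did not request an explicit split.
    bool has_split(const struct llama_context_params & p, int device_count) {
        if (p.tensor_split == nullptr) {
            return false;
        }
        for (int i = 0; i < device_count; ++i) {
            if (p.tensor_split[i] != 0.0f) {
                return true;
            }
        }
        return false;
    }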

llama.h

@@ -88,7 +88,8 @@ extern "C" {
         int32_t n_batch;      // prompt processing batch size
         int32_t n_gpu_layers; // number of layers to store in VRAM
         int32_t main_gpu;     // the GPU that is used for scratch and small tensors
-        float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+        const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
+
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency
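On the public API side, tensor_split is now a borrowed pointer to LLAMA_MAX_DEVICES floats rather than an inline array. A minimal consumer sketch, assuming a build with at least two devices (the 75/25 proportions are illustrative only):

    #include "llama.h"

    int main(void) {
        float split[LLAMA_MAX_DEVICES] = {0};
        split[0] = 0.75f; // 75% of layers on the main GPU
        split[1] = 0.25f; // 25% on the second GPU

        struct llama_context_params params = llama_context_default_params();
        params.tensor_split = split; // borrowed: keep `split` alive while params is in use
        // ... pass params on (e.g. to llama_init_from_file) before `split` goes out of scope
        return 0;
    }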