Merge 'origin/master' into hipblas
commit 391dd9a0e2
14 changed files with 1248 additions and 338 deletions
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;

@@ -236,8 +236,7 @@ class GGMLToGGUF:
            if len(vbytes) == 0:
                tt = 3 # Control
            elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1:
-               hv = hex(vbytes[0])[2:].upper()
-               vbytes = bytes(f'<0x{hv}>', encoding = 'UTF-8')
+               vbytes = bytes(f'<0x{vbytes[0]:02X}>', encoding = 'UTF-8')
                tt = 6 # Byte
            else:
                vbytes = vbytes.replace(b' ', b'\xe2\x96\x81')

@@ -72,12 +72,20 @@ int main(int argc, char ** argv) {
        fprintf(stderr, "\n");
    }

-    if (params.embedding){
-        if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
+    if (embd_inp.size() > (size_t)params.n_ctx) {
+        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
+                __func__, embd_inp.size(), params.n_ctx);
+        return 1;
+    }
+
+    while (!embd_inp.empty()) {
+        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }
+        n_past += n_tokens;
+        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
    }

    const int n_embd = llama_n_embd(ctx);

@@ -87,7 +95,6 @@ int main(int argc, char ** argv) {
            printf("%f ", embeddings[i]);
        }
        printf("\n");
-    }

    llama_print_timings(ctx);
    llama_free(ctx);

@@ -148,7 +148,7 @@ struct cmd_params {
};

static const cmd_params cmd_params_defaults = {
-    /* model */ {"models/7B/ggml-model-q4_0.bin"},
+    /* model */ {"models/7B/ggml-model-q4_0.gguf"},
    /* n_prompt */ {512},
    /* n_gen */ {128},
    /* n_batch */ {512},

@@ -179,12 +179,12 @@ static void print_usage(int /* argc */, char ** argv) {
    fprintf(stdout, " -mg i, --main-gpu <n> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
    fprintf(stdout, " -lv, --low-vram <0|1> (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str());
    fprintf(stdout, " -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
-    fprintf(stdout, " -ts, --tensor_split <ts> \n");
+    fprintf(stdout, " -ts, --tensor_split <ts0/ts1/..> \n");
    fprintf(stdout, " -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
-    fprintf(stdout, " -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : "md");
+    fprintf(stdout, " -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql");
    fprintf(stdout, " -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
    fprintf(stdout, "\n");
-    fprintf(stdout, "Multiple values can be given for each parameter by separating them with ',' or by repeating the parameter.\n");
+    fprintf(stdout, "Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");

}

@@ -728,7 +728,7 @@ struct markdown_printer : public printer {
        if (!is_cpu_backend) {
            fields.push_back("n_gpu_layers");
        }
-        if (params.n_batch.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
+        if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
            fields.push_back("n_threads");
        }
        if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
@@ -1056,33 +1056,42 @@ static json format_tokenizer_response(const std::vector<llama_token> &tokens)
        {"tokens", tokens}};
}

+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value)
+{
+    // Fallback null to default value
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
+
static void parse_options_completion(const json &body, llama_server_context &llama)
{
    gpt_params default_params;

-    llama.stream = body.value("stream", false);
-    llama.params.n_predict = body.value("n_predict", default_params.n_predict);
-    llama.params.top_k = body.value("top_k", default_params.top_k);
-    llama.params.top_p = body.value("top_p", default_params.top_p);
-    llama.params.tfs_z = body.value("tfs_z", default_params.tfs_z);
-    llama.params.typical_p = body.value("typical_p", default_params.typical_p);
-    llama.params.repeat_last_n = body.value("repeat_last_n", default_params.repeat_last_n);
-    llama.params.temp = body.value("temperature", default_params.temp);
-    llama.params.repeat_penalty = body.value("repeat_penalty", default_params.repeat_penalty);
-    llama.params.presence_penalty = body.value("presence_penalty", default_params.presence_penalty);
-    llama.params.frequency_penalty = body.value("frequency_penalty", default_params.frequency_penalty);
-    llama.params.mirostat = body.value("mirostat", default_params.mirostat);
-    llama.params.mirostat_tau = body.value("mirostat_tau", default_params.mirostat_tau);
-    llama.params.mirostat_eta = body.value("mirostat_eta", default_params.mirostat_eta);
-    llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
-    llama.params.n_keep = body.value("n_keep", default_params.n_keep);
-    llama.params.seed = body.value("seed", default_params.seed);
-    llama.params.prompt = body.value("prompt", default_params.prompt);
-    llama.params.grammar = body.value("grammar", default_params.grammar);
-    llama.params.n_probs = body.value("n_probs", default_params.n_probs);
+    llama.stream = json_value(body, "stream", false);
+    llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
+    llama.params.top_k = json_value(body, "top_k", default_params.top_k);
+    llama.params.top_p = json_value(body, "top_p", default_params.top_p);
+    llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
+    llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
+    llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
+    llama.params.temp = json_value(body, "temperature", default_params.temp);
+    llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
+    llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
+    llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
+    llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
+    llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
+    llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
+    llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+    llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
+    llama.params.seed = json_value(body, "seed", default_params.seed);
+    llama.params.prompt = json_value(body, "prompt", default_params.prompt);
+    llama.params.grammar = json_value(body, "grammar", default_params.grammar);
+    llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);

    llama.params.logit_bias.clear();
-    if (body.value("ignore_eos", false))
+    if (json_value(body, "ignore_eos", false))
    {
        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
    }

@@ -1337,7 +1346,7 @@ int main(int argc, char **argv)
            auto lock = llama.lock();

            const json body = json::parse(req.body);
-            const std::string content = body.value("content", "");
+            const std::string content = json_value<std::string>(body, "content", "");
            const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
            const json data = format_tokenizer_response(tokens);
            return res.set_content(data.dump(), "application/json"); });

@@ -1350,7 +1359,7 @@ int main(int argc, char **argv)

            llama.rewind();
            llama_reset_timings(llama.ctx);
-            llama.params.prompt = body.value("content", "");
+            llama.params.prompt = json_value<std::string>(body, "content", "");
            llama.params.n_predict = 0;
            llama.loadPrompt();
            llama.beginCompletion();

@@ -1379,7 +1388,7 @@ int main(int argc, char **argv)
        {
            if (res.status == 400) {
                res.set_content("Invalid request", "text/plain");
-            } else {
+            } else if (res.status != 500) {
                res.set_content("File Not Found", "text/plain");
                res.status = 404;
            } });
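Note (added for clarity, not part of the commit): the json_value helper introduced above exists so that an explicit JSON null falls back to the caller's default instead of being forwarded by body.value(). A minimal usage sketch, assuming the json_value template from the hunk above is in scope and nlohmann::json is available; the key names are illustrative:

#include <nlohmann/json.hpp>
using json = nlohmann::json;

int demo() {
    const json body = json::parse(R"({"top_k": null, "n_predict": 64})");
    const int top_k     = json_value(body, "top_k", 40);      // present but null -> default 40
    const int n_predict = json_value(body, "n_predict", 128); // present          -> 64
    const int n_keep    = json_value(body, "n_keep", 0);      // missing          -> default 0
    return top_k + n_predict + n_keep; // 40 + 64 + 0
}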
@@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
    t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
    t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
    t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
    t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
    t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
    t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
    t04->grad = expand(gb, ggml_add_inplace(ctx0,
                ggml_add_inplace(ctx0,

@@ -76,7 +76,7 @@ struct ggml_allocr {
};

#ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
    for (int i = 0; i < 1024; i++) {
        if (alloc->allocated_tensors[i] == NULL) {
            alloc->allocated_tensors[i] = tensor;

@@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens
    }
    GGML_ASSERT(!"out of allocated_tensors");
}
-static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
    for (int i = 0; i < 1024; i++) {
        if (alloc->allocated_tensors[i] == tensor ||
            (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
ggml-cuda.cu (152 changed lines)
@@ -360,6 +360,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

@@ -3987,13 +3988,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
        return;
    }

-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int i = row*ncols + col;

    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);

@@ -4041,9 +4042,32 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
        dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
    }
}

-static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;

+    if (col >= ncols) {
+        return;
+    }
+
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
    if (col >= ncols) {
        return;

@@ -4059,9 +4083,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
// values are also not normalized to the maximum value by subtracting it in the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;

    float tmp = 0.0;

@@ -4853,9 +4877,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}

@@ -4867,16 +4891,25 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
}

+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}

@@ -5610,6 +5643,41 @@ inline void ggml_cuda_op_rope(
    (void) i1;
}

+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
inline void ggml_cuda_op_diag_mask_inf(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
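Aside (not part of the commit): the m0/m1 values computed in ggml_cuda_op_alibi above define the per-head ALiBi slopes that alibi_f32 applies as dst[i] = col * m_k + x[i]. A small host-side C++ sketch of the same slope schedule, for reference:

#include <cmath>
#include <vector>

// Mirrors the slope math from ggml_cuda_op_alibi/alibi_f32 above: heads below
// n_heads_log2_floor use successive powers of m0, the rest use odd powers of m1.
static std::vector<float> alibi_slopes(int n_head, float max_bias) {
    const int n_heads_log2_floor = 1 << (int) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -max_bias / n_heads_log2_floor);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    std::vector<float> slopes(n_head);
    for (int k = 0; k < n_head; ++k) {
        slopes[k] = k < n_heads_log2_floor
            ? std::pow(m0, (float) (k + 1))
            : std::pow(m1, (float) (2 * (k - n_heads_log2_floor) + 1));
    }
    return slopes;
}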
@@ -6230,6 +6298,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
}

+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    (void) src0;
    (void) src1;

@@ -6349,7 +6422,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
    return extra;
}

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
    if (scratch && g_scratch_size == 0) {
        return;
    }

@@ -6358,14 +6431,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
    if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
        const ggml_op src0_op = tensor->src[0]->op;
        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
        }
    }
    if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
    }

    tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
    struct ggml_tensor_extra_gpu * extra;

    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||

@@ -6417,16 +6495,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
    tensor->extra = extra;
}

+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
}

void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
}

void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}

void ggml_cuda_set_main_device(int main_device) {

@@ -6565,6 +6675,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
            }
            func = ggml_cuda_rope;
            break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
        default:
            return false;
    }
@@ -16,9 +16,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
GGML_API void ggml_cuda_set_main_device(int main_device);
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
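A hedged note on the two new entry points (inferred from this diff, not stated in it): ggml_cuda_assign_buffers_no_alloc appears to mark a tensor for GPU offload while skipping the scratch allocation, and ggml_cuda_assign_scratch_offset is then used to point the tensor at its slice of the shared scratch buffer once an offset is known. A sketch of that call order, assuming a CUDA build:

#include "ggml.h"
#include "ggml-cuda.h"

// Sketch only: in practice the offset would come from the graph allocator,
// not be chosen by hand as it is here.
static void offload_with_deferred_scratch(struct ggml_tensor * t, size_t scratch_offset) {
    ggml_cuda_assign_buffers_no_alloc(t);                // mark for GPU, skip allocation
    ggml_cuda_assign_scratch_offset(t, scratch_offset);  // bind to the scratch buffer later
}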
@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
        //load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \

@@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const uchar * src0,
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                    + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
-            threadgroup_barrier(mem_flags::mem_device);
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

-        threadgroup_barrier(mem_flags::mem_device);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg==0) {
            for (int i = 0; i < n_rows; i++) {
ggml.h (120 changed lines)
@@ -211,6 +211,7 @@
#define GGML_MAX_OP_PARAMS 32
#define GGML_DEFAULT_N_THREADS 4

+
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1

@@ -259,8 +260,9 @@
extern "C" {
#endif

-#ifdef __ARM_NEON
-    // we use the built-in 16-bit float type
+#if defined(__ARM_NEON) && defined(__CUDACC__)
+    typedef half ggml_fp16_t;
+#elif defined(__ARM_NEON)
    typedef __fp16 ggml_fp16_t;
#else
    typedef uint16_t ggml_fp16_t;

@@ -344,10 +346,12 @@ extern "C" {
        GGML_OP_ARGMAX,
        GGML_OP_REPEAT,
        GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
        GGML_OP_SILU_BACK,
        GGML_OP_NORM, // normalize
        GGML_OP_RMS_NORM,
        GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,

        GGML_OP_MUL_MAT,
        GGML_OP_OUT_PROD,

@@ -373,14 +377,19 @@ extern "C" {
        GGML_OP_CLAMP,
        GGML_OP_CONV_1D,
        GGML_OP_CONV_2D,
+        GGML_OP_CONV_TRANSPOSE_2D,
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,

+        GGML_OP_UPSCALE, // nearest interpolate
+
        GGML_OP_FLASH_ATTN,
        GGML_OP_FLASH_FF,
        GGML_OP_FLASH_ATTN_BACK,
        GGML_OP_WIN_PART,
        GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,

        GGML_OP_UNARY,

@@ -804,6 +813,13 @@ extern "C" {
            struct ggml_tensor * a,
            struct ggml_tensor * b);

+    // concat a and b on dim 2
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
    GGML_API struct ggml_tensor * ggml_abs(
            struct ggml_context * ctx,
            struct ggml_tensor * a);

@@ -912,6 +928,19 @@ extern "C" {
            struct ggml_tensor * a,
            float eps);

+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int n_groups);
+
    // a - x
    // b - dy
    // TODO: update with configurable eps

@@ -1212,6 +1241,15 @@ extern "C" {
            float freq_base,
            float freq_scale);

+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int n_past,
+            int n_dims,
+            float base,
+            bool down);
+
    // rotary position embedding backward, i.e compute dx from dy
    // a - dy
    GGML_API struct ggml_tensor * ggml_rope_back(

@@ -1220,7 +1258,11 @@ extern "C" {
            int n_past,
            int n_dims,
            int mode,
-            int n_ctx);
+            int n_ctx,
+            float freq_base,
+            float freq_scale,
+            float xpos_base,
+            bool xpos_down);

    // alibi position embedding
    // in-place, returns view(a)

@@ -1247,6 +1289,15 @@ extern "C" {
            int p0, // padding
            int d0); // dilation

+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            int s,
+            int d);
+
    GGML_API struct ggml_tensor * ggml_conv_2d(
            struct ggml_context * ctx,
            struct ggml_tensor * a,

@@ -1258,14 +1309,38 @@ extern "C" {
            int d0,
            int d1);

-    // conv_1d with padding = half
-    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-    GGML_API struct ggml_tensor * ggml_conv_1d_ph(
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a: 16 16 3 768
+    // b: 1024 1024 3 1
+    // res: 64 64 768 1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a: 3 3 256 256
+    // b: 64 64 256 1
+    // res: 64 64 256 1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b,
-            int s,
-            int d);
+            int stride);

    enum ggml_op_pool {
        GGML_OP_POOL_MAX,

@@ -1292,6 +1367,13 @@ extern "C" {
            int p0,
            int p1);

+    // nearest interpolate
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int scale_factor);
+
    GGML_API struct ggml_tensor * ggml_flash_attn(
            struct ggml_context * ctx,
            struct ggml_tensor * q,

@@ -1345,6 +1427,27 @@ extern "C" {
            struct ggml_tensor * a,
            enum ggml_unary_op op);

+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int qh,
+            int kh);
+
+    // used in sam
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * pw,
+            struct ggml_tensor * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * pw,
+            struct ggml_tensor * ph);
+
    // custom operators

    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

@@ -1499,6 +1602,7 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor * tensor);

+
    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
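Cross-reference (added note, not part of the commit): the four literals that the training example earlier in this diff now passes to ggml_rope_back (10000.0f, 1.0f, 0.0f, false) line up, in order, with the freq_base, freq_scale, xpos_base and xpos_down parameters added to the declaration above. A small sketch of the call with the arguments named:

#include "ggml.h"

// Thin wrapper for illustration only; the values are the ones the updated
// training example passes, shown here against the new parameter names.
static struct ggml_tensor * rope_back_named(struct ggml_context * ctx, struct ggml_tensor * dy,
                                            int n_past, int n_dims, int mode, int n_ctx) {
    return ggml_rope_back(ctx, dy, n_past, n_dims, mode, n_ctx,
                          /* freq_base  = */ 10000.0f,
                          /* freq_scale = */ 1.0f,
                          /* xpos_base  = */ 0.0f,
                          /* xpos_down  = */ false);
}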
llama.cpp (235 changed lines)
@ -10,13 +10,7 @@
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#if !defined(GGML_USE_CUBLAS)
|
|
||||||
#include "ggml-alloc.h"
|
#include "ggml-alloc.h"
|
||||||
# define LLAMA_USE_ALLOCATOR
|
|
||||||
#else
|
|
||||||
# define LLAMA_USE_SCRATCH
|
|
||||||
# define LLAMA_MAX_SCRATCH_BUFFERS 16
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
# include "ggml-cuda.h"
|
# include "ggml-cuda.h"
|
||||||
|
@ -588,14 +582,6 @@ struct llama_state {
|
||||||
|
|
||||||
static llama_state g_state;
|
static llama_state g_state;
|
||||||
|
|
||||||
//
|
|
||||||
// memory sizes (calculated for n_batch == 512)
|
|
||||||
//
|
|
||||||
|
|
||||||
// computed for n_ctx == 2048
|
|
||||||
// TODO: dynamically determine these sizes
|
|
||||||
// needs modifications in ggml
|
|
||||||
|
|
||||||
// available llama models
|
// available llama models
|
||||||
enum e_model {
|
enum e_model {
|
||||||
MODEL_UNKNOWN,
|
MODEL_UNKNOWN,
|
||||||
|
@ -610,76 +596,6 @@ enum e_model {
|
||||||
static const size_t kB = 1024;
|
static const size_t kB = 1024;
|
||||||
static const size_t MB = 1024*1024;
|
static const size_t MB = 1024*1024;
|
||||||
|
|
||||||
static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
|
|
||||||
{
|
|
||||||
std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB },
|
|
||||||
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
|
|
||||||
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
|
|
||||||
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
|
||||||
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
|
|
||||||
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 128ull * MB },
|
|
||||||
{ MODEL_7B, 160ull * MB },
|
|
||||||
{ MODEL_13B, 192ull * MB },
|
|
||||||
{ MODEL_30B, 256ull * MB },
|
|
||||||
{ MODEL_65B, 384ull * MB }, // guess
|
|
||||||
{ MODEL_70B, 304ull * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// used to store the compute graph tensors + non-scratch data
|
|
||||||
static const std::map<e_model, size_t> & MEM_REQ_EVAL()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 8ull * MB },
|
|
||||||
{ MODEL_7B, 10ull * MB },
|
|
||||||
{ MODEL_13B, 12ull * MB },
|
|
||||||
{ MODEL_30B, 16ull * MB },
|
|
||||||
{ MODEL_65B, 24ull * MB }, // guess
|
|
||||||
{ MODEL_70B, 24ull * MB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// amount of VRAM needed per batch size to hold temporary results
|
|
||||||
// the values for 3b are not derived from testing but instead chosen conservatively
|
|
||||||
static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 512ull * kB },
|
|
||||||
{ MODEL_7B, 512ull * kB },
|
|
||||||
{ MODEL_13B, 640ull * kB },
|
|
||||||
{ MODEL_30B, 768ull * kB },
|
|
||||||
{ MODEL_65B, 1280ull * kB },
|
|
||||||
{ MODEL_70B, 1280ull * kB },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// amount of VRAM needed per batch size and context to hold temporary results
|
|
||||||
// the values for 3b are not derived from testing but instead chosen conservatively
|
|
||||||
static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
|
|
||||||
{
|
|
||||||
static std::map<e_model, size_t> k_sizes = {
|
|
||||||
{ MODEL_3B, 128ull },
|
|
||||||
{ MODEL_7B, 128ull },
|
|
||||||
{ MODEL_13B, 160ull },
|
|
||||||
{ MODEL_30B, 208ull },
|
|
||||||
{ MODEL_65B, 256ull },
|
|
||||||
{ MODEL_70B, 256ull },
|
|
||||||
};
|
|
||||||
return k_sizes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// default hparams (LLaMA 7B)
|
// default hparams (LLaMA 7B)
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
uint32_t n_vocab = 32000;
|
uint32_t n_vocab = 32000;
|
||||||
|
@ -857,11 +773,9 @@ struct llama_context {
|
||||||
ggml_metal_free(ctx_metal);
|
ggml_metal_free(ctx_metal);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
if (alloc) {
|
if (alloc) {
|
||||||
ggml_allocr_free(alloc);
|
ggml_allocr_free(alloc);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
@ -901,17 +815,8 @@ struct llama_context {
|
||||||
// memory buffers used to evaluate the model
|
// memory buffers used to evaluate the model
|
||||||
llama_buffer buf_compute;
|
llama_buffer buf_compute;
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
llama_buffer buf_alloc;
|
llama_buffer buf_alloc;
|
||||||
ggml_allocr * alloc = NULL;
|
ggml_allocr * alloc = NULL;
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef LLAMA_USE_SCRATCH
|
|
||||||
llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
|
||||||
|
|
||||||
int buf_last = 0;
|
|
||||||
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
ggml_metal_context * ctx_metal = NULL;
|
ggml_metal_context * ctx_metal = NULL;
|
||||||
|
@ -920,37 +825,6 @@ struct llama_context {
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
ggml_mpi_context * ctx_mpi = NULL;
|
ggml_mpi_context * ctx_mpi = NULL;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void use_buf(struct ggml_context * ctx, int i) { // NOLINT
|
|
||||||
#if defined(LLAMA_USE_SCRATCH)
|
|
||||||
size_t last_size = 0;
|
|
||||||
|
|
||||||
if (i == -1) {
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
||||||
} else {
|
|
||||||
auto & buf = buf_scratch[i];
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf_last >= 0) {
|
|
||||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
buf_last = i;
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
(void) ctx;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_buf_max_mem(int i) { // NOLINT
|
|
||||||
#if defined(LLAMA_USE_SCRATCH)
|
|
||||||
return buf_max_size[i];
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
return 0;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -1620,7 +1494,6 @@ static void llama_model_load_internal(
|
||||||
|
|
||||||
// prepare memory for the weights
|
// prepare memory for the weights
|
||||||
size_t vram_weights = 0;
|
size_t vram_weights = 0;
|
||||||
size_t vram_scratch = 0;
|
|
||||||
{
|
{
|
||||||
const uint32_t n_embd = hparams.n_embd;
|
const uint32_t n_embd = hparams.n_embd;
|
||||||
const uint32_t n_embd_gqa = hparams.n_embd_gqa();
|
const uint32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||||
|
@ -1701,13 +1574,6 @@ static void llama_model_load_internal(
|
||||||
ctx_size +
|
ctx_size +
|
||||||
mmapped_size - vram_weights; // weights in VRAM not in memory
|
mmapped_size - vram_weights; // weights in VRAM not in memory
|
||||||
|
|
||||||
#ifndef LLAMA_USE_ALLOCATOR
|
|
||||||
mem_required +=
|
|
||||||
MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
|
|
||||||
MEM_REQ_SCRATCH1().at(model.type) +
|
|
||||||
MEM_REQ_EVAL().at(model.type);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// this is the memory required by one llama_state
|
// this is the memory required by one llama_state
|
||||||
const size_t mem_required_state =
|
const size_t mem_required_state =
|
||||||
scale*hparams.kv_size();
|
scale*hparams.kv_size();
|
||||||
|
@ -1715,24 +1581,7 @@ static void llama_model_load_internal(
|
||||||
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
|
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
|
||||||
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
|
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
|
||||||
|
|
||||||
(void) vram_scratch;
|
|
||||||
(void) n_batch;
|
(void) n_batch;
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
if (low_vram) {
|
|
||||||
LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
|
|
||||||
ggml_cuda_set_scratch_size(0); // disable scratch
|
|
||||||
} else {
|
|
||||||
const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
|
|
||||||
const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
|
|
||||||
vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
|
|
||||||
ggml_cuda_set_scratch_size(vram_scratch);
|
|
||||||
if (n_gpu_layers > 0) {
|
|
||||||
LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
|
|
||||||
__func__, vram_scratch_base / kB, vram_scratch_per_context,
|
|
||||||
(vram_scratch + MB - 1) / MB); // round up
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif // GGML_USE_CUBLAS
|
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
||||||
|
@ -1769,8 +1618,8 @@ static void llama_model_load_internal(
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
|
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
|
||||||
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
|
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
|
||||||
LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
|
LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
|
||||||
__func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
|
__func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
|
||||||
#else
|
#else
|
||||||
(void) n_gpu_layers;
|
(void) n_gpu_layers;
|
||||||
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
|
@ -1875,9 +1724,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef LLAMA_USE_ALLOCATOR
|
|
||||||
params.no_alloc = true;
|
params.no_alloc = true;
|
||||||
#endif
|
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||

@@ -1889,14 +1736,10 @@ static struct ggml_cgraph * llama_build_graph(
     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
         }
-#else
-        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
         ggml_set_name(inp_tokens, "inp_tokens");

         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
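
Note: the surviving branch is the ggml-alloc "measure first, copy later" idiom: during the measuring pass the tensor has no real backing memory, so the input copy is skipped. A small sketch of the same idiom, with a hypothetical fill_input helper that is not part of the file:

    #include <cstring>

    #include "ggml.h"
    #include "ggml-alloc.h"

    // Illustrative helper: allocate through the allocator, but only memcpy once
    // the allocator is in normal (non-measuring) mode and t->data is real memory.
    static void fill_input(struct ggml_allocr * alloc, struct ggml_tensor * t,
                           const void * src, size_t nbytes) {
        ggml_allocr_alloc(alloc, t);           // reserves (or just measures) space for t
        if (!ggml_allocr_is_measure(alloc)) {  // skip the copy during the measure pass
            memcpy(t->data, src, nbytes);
        }
    }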
@@ -1907,14 +1750,10 @@ static struct ggml_cgraph * llama_build_graph(

         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);

-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
         }
-#else
-        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
     }

     const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1931,25 +1770,21 @@ static struct ggml_cgraph * llama_build_graph(

 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers;
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 1) {
-        offload_func_v = ggml_cuda_assign_buffers;
+        offload_func_v = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers;
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
     }
 #endif // GGML_USE_CUBLAS
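
Note: offload_func_nr / offload_func_v / offload_func_kq are per-tensor offload hooks; switching them to ggml_cuda_assign_buffers_no_alloc only tags tensors for the GPU and defers the actual placement until after ggml-alloc has laid out the graph. A sketch of the selection logic (the helper below is illustrative, not code from this file):

    #include "ggml.h"
    #ifdef GGML_USE_CUBLAS
    #include "ggml-cuda.h"
    #endif

    // Pick a per-tensor offload hook based on how many layers are mapped to the GPU.
    typedef void (*offload_func_t)(struct ggml_tensor * tensor);

    static void offload_nop(struct ggml_tensor * tensor) { // keep the tensor on the CPU
        (void) tensor;
    }

    static offload_func_t pick_offload_func(int n_gpu_layers, int n_layer, int extra_layers) {
    #ifdef GGML_USE_CUBLAS
        if (n_gpu_layers > n_layer + extra_layers) {
            // tag for the GPU; data is placed only after ggml-alloc has run
            return ggml_cuda_assign_buffers_no_alloc;
        }
    #endif
        (void) n_gpu_layers; (void) n_layer; (void) extra_layers;
        return offload_nop;
    }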

     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
-#else
-    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
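
Note: KQ_scale carries the standard attention scaling factor 1/sqrt(n_embd_head) with n_embd_head = n_embd / n_head. A worked example with illustrative 7B-style shapes (assumed values, not read from the model):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_embd = 4096;  // illustrative width
        const int n_head = 32;    // illustrative head count
        const float n_embd_head = float(n_embd) / n_head;     // 128
        const float kq_scale    = 1.0f / sqrtf(n_embd_head);  // ~0.0884
        printf("n_embd_head = %.0f, KQ_scale = %.4f\n", n_embd_head, kq_scale);
        return 0;
    }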

     for (int il = 0; il < n_layer; ++il) {
@@ -1959,14 +1794,12 @@ static struct ggml_cgraph * llama_build_graph(

 #ifdef GGML_USE_CUBLAS
         if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers;
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
         }
 #endif // GGML_USE_CUBLAS

         struct ggml_tensor * inpSA = inpL;

-        lctx.use_buf(ctx0, 0);

         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2104,8 +1937,6 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(cur, "result_wo");
         }

-        lctx.use_buf(ctx0, 1);

         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
@@ -2160,8 +1991,6 @@ static struct ggml_cgraph * llama_build_graph(
         inpL = cur;
     }

-    lctx.use_buf(ctx0, 0);

     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2178,8 +2007,6 @@ static struct ggml_cgraph * llama_build_graph(
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");

-    lctx.use_buf(ctx0, -1);

     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);

@@ -2189,15 +2016,6 @@ static struct ggml_cgraph * llama_build_graph(
         mem_per_token = ggml_used_mem(ctx0)/N;
     }

-#if 0
-    LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif

     ggml_free(ctx0);

     return gf;
@@ -2248,14 +2066,26 @@ static bool llama_eval_internal(
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;

-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_reset(lctx.alloc);
-#endif

     ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc_graph(lctx.alloc, gf);

+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < gf->n_leafs; i++) {
+        ggml_tensor * node = gf->leafs[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }

+    for (int i = 0; i < gf->n_nodes; i++) {
+        ggml_tensor * node = gf->nodes[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
 #endif
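
Note: the added loops walk both graph leafs and nodes and, for GPU-tagged tensors the backend has not claimed yet (extra == NULL), register where ggml-alloc placed them as a scratch offset. That offset is simply the distance from the start of the allocator's backing buffer, e.g.:

    #include <cstddef>

    // Sketch: the scratch offset handed to the CUDA backend is the position of the
    // tensor's data relative to the start of the allocator buffer.
    static size_t scratch_offset(const void * tensor_data, const void * buf_base) {
        return (size_t) ((const char *) tensor_data - (const char *) buf_base);
    }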

     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

@@ -4319,7 +4149,6 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }

-#ifdef LLAMA_USE_ALLOCATOR
         {
             static const size_t tensor_alignment = 32;
             // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
@@ -4350,13 +4179,6 @@ struct llama_context * llama_new_context_with_model(

             LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);

-            // debug - for comparison with scratch buffer
-            //size_t prev_req =
-            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
-            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
-            //    MEM_REQ_EVAL().at(ctx->model.type);
-            //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);

             // recreate allocator with exact memory requirements
             ggml_allocr_free(ctx->alloc);
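
Note: this is the measure/re-create pattern: a measuring allocator runs over a worst-case graph to learn the exact tensor memory requirement, then it is freed and replaced by a real allocator of that size. A rough sketch of the flow, assuming the ggml-alloc API of this period (worst_case_graph and buf_alloc are placeholders):

    #include <cstdint>
    #include <vector>

    #include "ggml.h"
    #include "ggml-alloc.h"

    // Sketch only: worst_case_graph stands in for the largest graph the context
    // will ever evaluate (full batch at full context), built with no_alloc = true.
    static struct ggml_allocr * create_exact_alloc(struct ggml_cgraph * worst_case_graph,
                                                   std::vector<uint8_t> & buf_alloc) {
        static const size_t tensor_alignment = 32;

        // 1) measure: no memory is committed, we only learn how much would be needed
        struct ggml_allocr * measure = ggml_allocr_new_measure(tensor_alignment);
        const size_t alloc_size = ggml_allocr_alloc_graph(measure, worst_case_graph) + tensor_alignment;
        ggml_allocr_free(measure);

        // 2) recreate the allocator with exact memory requirements
        buf_alloc.resize(alloc_size);
        return ggml_allocr_new(buf_alloc.data(), buf_alloc.size(), tensor_alignment);
    }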

@@ -4367,16 +4189,17 @@ struct llama_context * llama_new_context_with_model(
                 ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
             }
 #endif
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            }
         }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif

-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
 #endif
     }
+    }

 #ifdef GGML_USE_METAL
     if (params.n_gpu_layers > 0) {

@@ -1,6 +1,7 @@
 #!/bin/bash

 cp -rpv ../ggml/src/ggml.c ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
 cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
 cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
 cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
@@ -9,6 +10,7 @@ cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
 cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
 cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
 cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h

 cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp
 cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp