Merge branch 'master' of https://github.com/ggerganov/llama.cpp into clang-warnings
commit a7d13ac15b

7 changed files with 77 additions and 54 deletions
README.md

@@ -11,6 +11,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

 ### Hot topics

+- Parallel decoding + continuous batching support incoming: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \
+  **Devs should become familiar with the new API**
 - Local Falcon 180B inference on Mac Studio

   https://github.com/ggerganov/llama.cpp/assets/1991296/98abd4e8-7077-464c-ae89-aebabca7757e

@@ -555,6 +557,10 @@ python3 convert.py models/7B/

 # quantize the model to 4-bits (using q4_0 method)
 ./quantize ./models/7B/ggml-model-f16.gguf ./models/7B/ggml-model-q4_0.gguf q4_0

+# update the gguf filetype to current if older version is unsupported by another application
+./quantize ./models/7B/ggml-model-q4_0.gguf ./models/7B/ggml-model-q4_0-v2.gguf COPY
+
 # run the inference
 ./main -m ./models/7B/ggml-model-q4_0.gguf -n 128
 ```
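
(A note on the new COPY step: as the added comment says, it refreshes the file in the current GGUF layout/filetype without re-quantizing the tensor data, so it is cheap compared to a real quantization pass.)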

build.zig (22 changed lines)

@@ -36,17 +36,20 @@ const Maker = struct {
     }

     fn init(builder: *std.build.Builder) !Maker {
-        const commit_hash = @embedFile(".git/refs/heads/master");
+        // const commit_hash = @embedFile(".git/refs/heads/master");
+        const target = builder.standardTargetOptions(.{});
         const config_header = builder.addConfigHeader(
             .{ .style = .blank, .include_path = "build-info.h" },
             .{
                 .BUILD_NUMBER = 0,
-                .BUILD_COMMIT = commit_hash[0 .. commit_hash.len - 1], // omit newline
+                .BUILD_COMMIT = "12345", // omit newline
+                .BUILD_COMPILER = "Zig 0.11.0",
+                .BUILD_TARGET = try target.allocDescription(builder.allocator),
             },
         );
         var m = Maker{
             .builder = builder,
-            .target = builder.standardTargetOptions(.{}),
+            .target = target,
             .optimize = builder.standardOptimizeOption(.{}),
             .config_header = config_header,
             .enable_lto = false,

@@ -58,7 +61,7 @@ const Maker = struct {
         try m.addCFlag("-std=c11");
         try m.addCxxFlag("-std=c++11");
         try m.addProjectInclude(&.{});
-        try m.addProjectInclude(&.{"examples"});
+        try m.addProjectInclude(&.{"common"});
         return m;
     }

@@ -71,6 +74,7 @@ const Maker = struct {
             o.addCSourceFiles(&.{src}, m.cxxflags.items);
             o.linkLibCpp();
         }
+        o.addConfigHeader(m.config_header);
         for (m.include_dirs.items) |i| o.addIncludePath(.{ .path = i });
         o.want_lto = m.enable_lto;
         return o;

@@ -104,15 +108,15 @@ pub fn build(b: *std.build.Builder) !void {
     const ggml = make.obj("ggml", "ggml.c");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const llama = make.obj("llama", "llama.cpp");
-    const common = make.obj("common", "examples/common.cpp");
-    const console = make.obj("common", "examples/console.cpp");
-    const grammar_parser = make.obj("grammar-parser", "examples/grammar-parser.cpp");
+    const common = make.obj("common", "common/common.cpp");
+    const console = make.obj("common", "common/console.cpp");
+    const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");

     _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, llama, common, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, llama });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, llama, common });
     _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, llama, common });
     _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, llama, common });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, llama });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, llama, common });

     const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, llama, common, grammar_parser });
     if (server.target.isWindows()) {
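
For context on what `addConfigHeader` produces here: the generated `build-info.h` carries `BUILD_NUMBER`, `BUILD_COMMIT`, `BUILD_COMPILER`, and `BUILD_TARGET` as plain macros. Below is a minimal sketch of how the C++ side can consume such a header; the fallback defines and the banner format are illustrative, not the project's exact code:

```cpp
// Illustrative consumer of the generated build-info.h.
// The fallbacks let this snippet compile standalone when the
// header has not been generated yet.
#if defined(__has_include)
# if __has_include("build-info.h")
#  include "build-info.h"
# endif
#endif

#ifndef BUILD_NUMBER
#define BUILD_NUMBER 0
#endif
#ifndef BUILD_COMMIT
#define BUILD_COMMIT "unknown"
#endif

#include <cstdio>

int main() {
    // Tools in the repo print a similar banner at startup.
    printf("build: %d (%s)\n", BUILD_NUMBER, BUILD_COMMIT);
    return 0;
}
```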

common/common.cpp

@@ -647,9 +647,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --cfg-negative-prompt-file FNAME\n");
     printf("                        negative prompt file to use for guidance. (default: empty)\n");
     printf("  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
-    printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
-    printf("  --rope-freq-scale N   RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
+    printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+    printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
+    printf("  --rope-freq-scale N   RoPE frequency linear scaling factor (default: loaded from model)\n");
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
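
The two flags above are reciprocals: stretching the context by a factor s (--rope-scale s) is the same as scaling the rotary frequencies by 1/s (--rope-freq-scale 1/s), which is exactly the `1.0f/params.rope_freq_scale` conversion the old help text printed. A one-line sanity check:

```cpp
#include <cassert>

int main() {
    float rope_freq_scale = 0.5f;                   // e.g. value stored in the model
    float rope_scale      = 1.0f / rope_freq_scale; // the --rope-scale equivalent
    assert(rope_scale == 2.0f);                     // i.e. a 2x longer context window
    return 0;
}
```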

examples/embedding/README.md

@@ -1,3 +1,21 @@
-# embedding
+# llama.cpp/example/embedding

-TODO
+This example demonstrates how to generate a high-dimensional embedding vector for a given text with llama.cpp.
+
+## Quick Start
+
+To get started right away, run the following command, making sure to use the correct path for the model you have:
+
+### Unix-based systems (Linux, macOS, etc.):
+
+```bash
+./embedding -m ./path/to/model --log-disable -p "Hello World!" 2>/dev/null
+```
+
+### Windows:
+
+```powershell
+embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null
+```
+
+The above command will output space-separated float values.
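
Since the output is plain space-separated floats on stdout, a consumer can parse it with formatted input alone. A minimal sketch (the `consume` binary name is hypothetical; the vector length is model-dependent, e.g. 4096 for LLaMA-7B):

```cpp
#include <iostream>
#include <vector>

// Usage sketch:
//   ./embedding -m model.gguf --log-disable -p "Hello" 2>/dev/null | ./consume
int main() {
    std::vector<float> embedding;
    float v;
    while (std::cin >> v) { // read whitespace-separated floats until EOF
        embedding.push_back(v);
    }
    std::cout << "dimensions: " << embedding.size() << "\n";
    return 0;
}
```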

examples/server/server.cpp

@@ -701,8 +701,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     printf("  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
-    printf("  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
-    printf("  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    printf("  --rope-freq-base N    RoPE base frequency (default: loaded from model)\n");
+    printf("  --rope-freq-scale N   RoPE frequency scaling factor (default: loaded from model)\n");
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");

ggml-opencl.cpp

@@ -1788,9 +1788,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
         return false;
     }
 }

-static bool ggml_cl_mul_mat_use_f16(
-    const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */
-) {
+static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
     // If device doesn't support FP16
     if (!fp16_support) {
         return false;

llama.cpp (63 changed lines)

@@ -929,23 +929,22 @@ static const size_t kB = 1024;
 static const size_t MB = kB*kB;
 static const size_t GB = kB*kB*kB;

-// default hparams (LLaMA 7B)
 struct llama_hparams {
-    uint32_t n_vocab     = 32000;
-    uint32_t n_ctx_train = 2048; // the context size used during training
-    uint32_t n_ctx       = 512;  // the context size used during inference
-    uint32_t n_embd      = 4096;
-    uint32_t n_head      = 32;
-    uint32_t n_head_kv   = 32;
-    uint32_t n_layer     = 32;
-    uint32_t n_rot       = 64;
-    uint32_t n_ff        = 11008;
+    uint32_t n_vocab;
+    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx;       // context size used during inference
+    uint32_t n_embd;
+    uint32_t n_head;
+    uint32_t n_head_kv;
+    uint32_t n_layer;
+    uint32_t n_rot;
+    uint32_t n_ff;

-    float f_norm_eps     = 1e-5;
-    float f_norm_rms_eps = 1e-5;
+    float f_norm_eps;
+    float f_norm_rms_eps;

-    float rope_freq_base  = 10000.0f;
-    float rope_freq_scale = 1.0f;
+    float rope_freq_base;
+    float rope_freq_scale;

     bool operator!=(const llama_hparams & other) const {
         return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT

@@ -1076,7 +1075,7 @@ struct llama_model {

     std::string name = "n/a";

-    llama_hparams hparams;
+    llama_hparams hparams = {};
     llama_vocab vocab;

     struct ggml_tensor * tok_embeddings;
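
Removing the in-struct defaults (previous hunk) makes this `= {}` load-bearing: `operator!=` compares raw bytes via `memcmp`, so every member must be deterministically initialized before any comparison. A minimal illustration of the pattern, using a hypothetical `Params` type:

```cpp
#include <cstring>

struct Params {
    unsigned n_vocab;        // no member initializers, mirroring llama_hparams
    float    rope_freq_base;

    bool operator!=(const Params & other) const {
        // byte-wise comparison: only meaningful if all bytes are deterministic
        return memcmp(this, &other, sizeof(Params)) != 0;
    }
};

int main() {
    Params a = {}; // value-initialized: members start zeroed
    Params b = {};
    return (a != b) ? 1 : 0; // returns 0: the two structs compare equal
}
```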

@@ -1674,29 +1673,18 @@ static void llm_load_hparams(
     hparams.n_head_kv = hparams.n_head;
     GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));

-    // TODO: manually setting rope freq base and scale should override this
-    // FIXME: partial fix when the param specified is not the default value, but
-    // will not work for overriding the model value to the params default
-    llama_context_params defaults = llama_context_default_params();
-
-    // rope_freq_base
-    {
-        float ropebase = 10000.0f;
-        GGUF_GET_KEY(ctx, ropebase, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
-        if (ropebase != 10000.0f && rope_freq_base == defaults.rope_freq_base) {
-            rope_freq_base = ropebase;
-        }
+    // rope_freq_base (optional)
+    if (rope_freq_base == 0.0f) {
+        rope_freq_base = 10000.0f;
+        GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
     }

     // rope_freq_scale (inverse of the kv) is optional
-    {
+    if (rope_freq_scale == 0.0f) {
         float ropescale = 1.0f;
         GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-        if (ropescale != 1.0f && rope_freq_scale == defaults.rope_freq_scale) {
-            rope_freq_scale = 1.0f/ropescale;
-        }
+        rope_freq_scale = 1.0f/ropescale;
     }

     // sanity check for n_rot (optional)
     {
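
The new shape of this block, combined with the 0.0f defaults in llama_context_default_params at the end of this diff, gives a clear resolution order: an explicit user value wins, otherwise the model's GGUF metadata, otherwise the built-in default. A self-contained sketch of that order, with `read_f32_from_gguf` standing in (hypothetically) for the GGUF_GET_KEY macro:

```cpp
#include <cstdio>

// Hypothetical stand-in for GGUF_GET_KEY: writes *out only when the
// key exists in the model's metadata, leaving the fallback otherwise.
static bool read_f32_from_gguf(const char * key, float * out) {
    (void) key; (void) out;
    return false; // pretend this model carries no rope.freq_base override
}

// Resolution order after this commit: user value > model metadata > default.
static float resolve_rope_freq_base(float user_value) {
    if (user_value != 0.0f) {
        return user_value; // explicit CLI/API override wins
    }
    float v = 10000.0f;    // built-in default
    read_f32_from_gguf("rope.freq_base", &v); // model metadata, if present
    return v;
}

int main() {
    printf("unset    -> %.1f\n", resolve_rope_freq_base(0.0f));    // 10000.0
    printf("override -> %.1f\n", resolve_rope_freq_base(5000.0f)); // 5000.0
    return 0;
}
```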

@@ -3777,6 +3765,15 @@ static bool llama_eval_internal(
         n_threads = std::min(4, n_threads);
     }

+    // If all tensors can be run on the GPU then using more than 1 thread is detrimental.
+    const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
+        model.arch == LLM_ARCH_BAICHUAN ||
+        model.arch == LLM_ARCH_FALCON;
+    const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
+    if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
+        n_threads = 1;
+    }
+
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
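
Why clamp to one thread: when every tensor is resident on the GPU, the CPU side of graph evaluation is reduced to launching kernels and waiting on them, so additional threads only contend on that synchronization. The n_layer + 3 threshold appears to leave headroom for the handful of non-repeating tensors beyond the per-layer ones before the model counts as fully offloaded.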

@@ -6188,8 +6185,8 @@ struct llama_context_params llama_context_default_params() {
         /*.n_gpu_layers                =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rope_freq_base              =*/ 10000.0f,
-        /*.rope_freq_scale             =*/ 1.0f,
+        /*.rope_freq_base              =*/ 0.0f,
+        /*.rope_freq_scale             =*/ 0.0f,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,