Merge branch 'master' of https://github.com/JoanFM/llama.cpp into feat-jina-embeddings

Joan Martinez 2024-05-09 09:42:38 +02:00
commit 849aeda215
45 changed files with 3469 additions and 1444 deletions


@@ -405,6 +405,7 @@ if (LLAMA_CUDA)
     list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
     add_compile_definitions(GGML_USE_CUDA)
+    add_compile_definitions(GGML_CUDA_USE_GRAPHS)
     if (LLAMA_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
@@ -430,7 +431,7 @@ if (LLAMA_CUDA)
     if (LLAMA_STATIC)
         if (WIN32)
-            # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
+            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
         else ()
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)


@@ -433,7 +433,7 @@ ifdef LLAMA_CUDA
 else
 	CUDA_PATH ?= /usr/local/cuda
 endif
-MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
+MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
 MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
 OBJS += ggml-cuda.o
 OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))


@@ -140,7 +140,6 @@ Typically finetunes of the base models below are supported as well.
 - [x] [MobileVLM 1.7B/3B models](https://huggingface.co/models?search=mobileVLM)
 - [x] [Yi-VL](https://huggingface.co/models?search=Yi-VL)
 - [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
-- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
 **HTTP server**


@@ -1,4 +1,6 @@
 #include "common.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
@@ -1969,18 +1971,18 @@ static bool llama_download_file(const std::string & url, const std::string & pat
         try {
             metadata_in >> metadata;
             fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
-            if (metadata.contains("url") && metadata["url"].is_string()) {
-                auto previous_url = metadata["url"].get<std::string>();
+            if (metadata.contains("url") && metadata.at("url").is_string()) {
+                auto previous_url = metadata.at("url").get<std::string>();
                 if (previous_url != url) {
                     fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
                     return false;
                 }
             }
-            if (metadata.contains("etag") && metadata["etag"].is_string()) {
-                etag = metadata["etag"];
+            if (metadata.contains("etag") && metadata.at("etag").is_string()) {
+                etag = metadata.at("etag");
             }
-            if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) {
-                last_modified = metadata["lastModified"];
+            if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
+                last_modified = metadata.at("lastModified");
             }
         } catch (const nlohmann::json::exception & e) {
             fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());


@@ -1,4 +1,8 @@
 #pragma once
+#include "ggml.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);

File diff suppressed because it is too large


@@ -284,6 +284,7 @@ class Params:
         n_experts = None
         n_experts_used = None
         f_rope_freq_base = None
+        n_ff = None
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if config.get("moe"):
@@ -308,6 +309,8 @@ class Params:
             n_experts_used = config["moe"]["num_experts_per_tok"]
             f_rope_freq_base = 1e6
+        assert n_ff is not None
+
         return Params(
             n_vocab = model["tok_embeddings.weight"].shape[0],
             n_embd = config["dim"],
@@ -462,7 +465,8 @@ class SentencePieceVocab(Vocab):
                 # not found in alternate location either
                 raise FileNotFoundError('Cannot find tokenizer.model')
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        self.sentencepiece_tokenizer = SentencePieceProcessor()
+        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
         vocab_size = self.sentencepiece_tokenizer.vocab_size()
         new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@@ -482,23 +486,23 @@ class SentencePieceVocab(Vocab):
     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         tokenizer = self.sentencepiece_tokenizer
         for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
+            piece = tokenizer.IdToPiece(i)
             text = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
+            score: float = tokenizer.GetScore(i)
             toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+            if tokenizer.IsUnknown(i):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if tokenizer.IsControl(i):
                 toktype = gguf.TokenType.CONTROL
             # NOTE: I think added_tokens are user defined.
             # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
             # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
-            if tokenizer.is_unused(i):
+            if tokenizer.IsUnused(i):
                 toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
+            if tokenizer.IsByte(i):
                 toktype = gguf.TokenType.BYTE
             yield text, score, toktype
@@ -906,7 +910,7 @@ class LazyUnpickler(pickle.Unpickler):
     def rebuild_from_type_v2(func, new_type, args, state):
         return func(*args)
-    CLASSES = {
+    CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
         ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),


@@ -104,7 +104,6 @@ static std::string format(const char * fmt, ...) {
 #define TN_POS_EMBD "%s.position_embd.weight"
 #define TN_CLASS_EMBD "v.class_embd"
 #define TN_PATCH_EMBD "v.patch_embd.weight"
-#define TN_PATCH_BIAS "v.patch_embd.bias"
 #define TN_ATTN_K "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V "%s.blk.%d.attn_v.%s"
@@ -426,7 +425,6 @@ struct clip_vision_model {
     // embeddings
     struct ggml_tensor * class_embedding;
     struct ggml_tensor * patch_embeddings;
-    struct ggml_tensor * patch_bias;
     struct ggml_tensor * position_embeddings;
     struct ggml_tensor * pre_ln_w;
@@ -503,11 +501,6 @@ struct clip_ctx {
     bool use_gelu = false;
     int32_t ftype = 1;
-    bool has_class_embedding = true;
-    bool has_pre_norm = true;
-    bool has_post_norm = false;
-    bool has_patch_bias = false;
     struct gguf_context * ctx_gguf;
     struct ggml_context * ctx_data;
@@ -533,7 +526,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     const int patch_size = hparams.patch_size;
     const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
     const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
-    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    const int num_positions = num_patches + 1;
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
@@ -564,23 +557,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
-    if (ctx->has_patch_bias) {
-        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
-        inp = ggml_add(ctx0, inp, model.patch_bias);
-    }
     // concat class_embeddings and patch_embeddings
-    struct ggml_tensor * embeddings = inp;
-    if (ctx->has_class_embedding) {
-        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-        embeddings = ggml_acc(ctx0, embeddings, inp,
-            embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
-    }
+    struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
     ggml_set_name(embeddings, "embeddings");
     ggml_set_input(embeddings);
+    embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+        embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+    embeddings = ggml_acc(ctx0, embeddings, inp,
+        embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
     ggml_set_name(positions, "positions");
@@ -590,7 +576,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
     // pre-layernorm
-    if (ctx->has_pre_norm) {
+    {
         embeddings = ggml_norm(ctx0, embeddings, eps);
         ggml_set_name(embeddings, "pre_ln");
@@ -678,14 +664,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = cur;
     }
-    // post-layernorm
-    if (ctx->has_post_norm) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
-    }
     // llava projector
     {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1170,39 +1148,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
-    try {
-        vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
-        new_clip->has_class_embedding = true;
-    } catch (const std::exception& e) {
-        new_clip->has_class_embedding = false;
-    }
-    try {
-        vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
-        vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
-        new_clip->has_pre_norm = true;
-    } catch (std::exception & e) {
-        new_clip->has_pre_norm = false;
-    }
-    try {
-        vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
-        vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
-        new_clip->has_post_norm = true;
-    } catch (std::exception & e) {
-        new_clip->has_post_norm = false;
-    }
-    try {
-        vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS);
-        new_clip->has_patch_bias = true;
-    } catch (std::exception & e) {
-        new_clip->has_patch_bias = false;
-    }
     try {
         vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+        vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
         vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
+        vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
+        vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
     } catch(const std::exception& e) {
         LOG_TEE("%s: failed to load vision model tensors\n", __func__);
     }

Binary file not shown (new image added; 4 KiB).


@@ -12,6 +12,8 @@
 // increase max payload length to allow use of larger context size
 #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
 #include "httplib.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 // auto generated files (update with ./deps.sh)
@@ -859,7 +861,7 @@ struct server_context {
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data["json_schema"].is_null() && data.contains("grammar") && !data["grammar"].is_null()) {
+        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
         } else if (data.contains("json_schema") && !data.contains("grammar")) {
@@ -1512,7 +1514,7 @@ struct server_context {
         // add subtasks
         for (int i = 0; i < prompt_count; i++) {
             json subtask_data = multiprompt_task.data;
-            subtask_data["prompt"] = subtask_data["prompt"][i];
+            subtask_data["prompt"] = subtask_data.at("prompt")[i];
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
             request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
@@ -1532,7 +1534,7 @@ struct server_context {
                     }
                     if (task.data.contains("system_prompt")) {
-                        system_prompt_set(task.data["system_prompt"]);
+                        system_prompt_set(task.data.at("system_prompt"));
                         for (server_slot & slot : slots) {
                             slot.n_past = 0;
@@ -1644,7 +1646,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1654,8 +1656,8 @@ struct server_context {
                     const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
-                    std::string filename = task.data["filename"];
-                    std::string filepath = task.data["filepath"];
+                    std::string filename = task.data.at("filename");
+                    std::string filepath = task.data.at("filepath");
                     const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
@@ -1679,7 +1681,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1688,8 +1690,8 @@ struct server_context {
                     const int64_t t_start = ggml_time_us();
-                    std::string filename = task.data["filename"];
-                    std::string filepath = task.data["filepath"];
+                    std::string filename = task.data.at("filename");
+                    std::string filepath = task.data.at("filepath");
                     slot->cache_tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
@@ -1721,7 +1723,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -3136,8 +3138,8 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(task.id);
         ctx_server.queue_results.remove_waiting_task_id(task.id);
-        const int n_idle_slots = result.data["idle"];
-        const int n_processing_slots = result.data["processing"];
+        const int n_idle_slots = result.data.at("idle");
+        const int n_processing_slots = result.data.at("processing");
         json health = {
             {"status", "ok"},
@@ -3147,7 +3149,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
         if (sparams.slots_endpoint && req.has_param("include_slots")) {
-            health["slots"] = result.data["slots"];
+            health["slots"] = result.data.at("slots");
         }
         if (n_idle_slots == 0) {
@@ -3191,7 +3193,7 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(task.id);
         ctx_server.queue_results.remove_waiting_task_id(task.id);
-        res.set_content(result.data["slots"].dump(), "application/json");
+        res.set_content(result.data.at("slots").dump(), "application/json");
         res.status = 200; // HTTP OK
     };
@@ -3218,32 +3220,32 @@ int main(int argc, char ** argv) {
         json data = result.data;
-        const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
-        const uint64_t t_prompt_processing = data["t_prompt_processing"];
-        const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
-        const uint64_t t_tokens_generation = data["t_tokens_generation"];
-        const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+        const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
+        const uint64_t t_prompt_processing = data.at("t_prompt_processing");
+        const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
+        const uint64_t t_tokens_generation = data.at("t_tokens_generation");
+        const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
         json all_metrics_def = json {
             {"counter", {{
                     {"name", "prompt_tokens_total"},
                     {"help", "Number of prompt tokens processed."},
-                    {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
+                    {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")}
             }, {
                     {"name", "prompt_seconds_total"},
                     {"help", "Prompt process time"},
-                    {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+                    {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3}
             }, {
                     {"name", "tokens_predicted_total"},
                     {"help", "Number of generation tokens processed."},
-                    {"value", (uint64_t) data["n_tokens_predicted_total"]}
+                    {"value", (uint64_t) data.at("n_tokens_predicted_total")}
             }, {
                    {"name", "tokens_predicted_seconds_total"},
                    {"help", "Predict process time"},
-                    {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+                    {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
            }}},
            {"gauge", {{
                    {"name", "prompt_tokens_seconds"},
@@ -3260,15 +3262,15 @@ int main(int argc, char ** argv) {
            },{
                    {"name", "kv_cache_tokens"},
                    {"help", "KV-cache tokens."},
-                    {"value", (uint64_t) data["kv_cache_tokens_count"]}
+                    {"value", (uint64_t) data.at("kv_cache_tokens_count")}
            },{
                    {"name", "requests_processing"},
                    {"help", "Number of request processing."},
-                    {"value", (uint64_t) data["processing"]}
+                    {"value", (uint64_t) data.at("processing")}
            },{
                    {"name", "requests_deferred"},
                    {"help", "Number of request deferred."},
-                    {"value", (uint64_t) data["deferred"]}
+                    {"value", (uint64_t) data.at("deferred")}
            }}}
        };
@@ -3279,8 +3281,8 @@ int main(int argc, char ** argv) {
            const auto & metrics_def = el.value();
            for (const auto & metric_def : metrics_def) {
-                const std::string name = metric_def["name"];
-                const std::string help = metric_def["help"];
+                const std::string name = metric_def.at("name");
+                const std::string help = metric_def.at("help");
                auto value = json_value(metric_def, "value", 0.);
                prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
@@ -3289,7 +3291,7 @@ int main(int argc, char ** argv) {
            }
        }
-        const int64_t t_start = data["t_start"];
+        const int64_t t_start = data.at("t_start");
        res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
@@ -3298,7 +3300,7 @@ int main(int argc, char ** argv) {
    const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
        json request_data = json::parse(req.body);
-        std::string filename = request_data["filename"];
+        std::string filename = request_data.at("filename");
        if (!validate_file_name(filename)) {
            res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
            return;
@@ -3328,7 +3330,7 @@ int main(int argc, char ** argv) {
    const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
        json request_data = json::parse(req.body);
-        std::string filename = request_data["filename"];
+        std::string filename = request_data.at("filename");
        if (!validate_file_name(filename)) {
            res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
            return;
@@ -3648,7 +3650,7 @@ int main(int argc, char ** argv) {
        std::vector<llama_token> tokens;
        if (body.count("content") != 0) {
            const bool add_special = json_value(body, "add_special", false);
-            tokens = ctx_server.tokenize(body["content"], add_special);
+            tokens = ctx_server.tokenize(body.at("content"), add_special);
        }
        const json data = format_tokenizer_response(tokens);
        return res.set_content(data.dump(), "application/json; charset=utf-8");
@@ -3660,7 +3662,7 @@ int main(int argc, char ** argv) {
        std::string content;
        if (body.count("tokens") != 0) {
-            const std::vector<llama_token> tokens = body["tokens"];
+            const std::vector<llama_token> tokens = body.at("tokens");
            content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
        }
@@ -3683,10 +3685,10 @@ int main(int argc, char ** argv) {
        json prompt;
        if (body.count("input") != 0) {
            is_openai = true;
-            prompt = body["input"];
+            prompt = body.at("input");
        } else if (body.count("content") != 0) {
            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body["content"]};
+            prompt = std::vector<std::string>{body.at("content")};
        } else {
            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
@@ -3705,7 +3707,7 @@ int main(int argc, char ** argv) {
        if (!result.error) {
            if (result.data.count("results")) {
                // result for multi-task
-                responses = result.data["results"];
+                responses = result.data.at("results");
            } else {
                // result for single task
                responses = std::vector<json>{result.data};


@@ -939,7 +939,7 @@ async def oai_chat_completions(user_prompt,
     while event_received:
         event_received = False
         async for line_in_bytes in response.content:
-            line = line_in_bytes.decode('utf8')
+            line = line_in_bytes.decode('utf-8')
             line = line.rstrip('\n').rstrip('\r')
             if line == '':
                 continue


@@ -0,0 +1,5 @@
+# LLaMA.cpp Server Wild Theme
+
+Simple themes directory of sample "public" directories. To try any of these add --path to your run like `server --path=wild`.
+
+![image](wild/wild.png)


@@ -0,0 +1,7 @@
+# LLaMA.cpp Server Buttons Top Theme
+
+Simple tweaks to the UI. Chat buttons at the top of the page instead of bottom so you can hit Stop instead of chasing it down the page.
+
+To use simply run server with `--path=themes/buttons_top`
+
+![image](buttons_top.png)

Binary file not shown (new image added; 117 KiB).

Binary file not shown (new image added; 4 KiB).

File diff suppressed because it is too large


@@ -0,0 +1,5 @@
+# LLaMA.cpp Server Wild Theme
+
+Simple tweaks to the UI. To use simply run server with `--path=themes/wild`
+
+![image](wild.png)

Binary file not shown (new image added; 4 KiB).

File diff suppressed because it is too large

Binary file not shown (new image added; 75 KiB).

Binary file not shown (new image added; 254 KiB).

Binary file not shown (new image added; 485 KiB).


@@ -3,6 +3,8 @@
 #include "llama.h"
 #include "common.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 #include <string>
@@ -373,11 +375,11 @@ static json oaicompat_completion_params_parse(
     llama_params["top_p"] = json_value(body, "top_p", 1.0);
     // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
+    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
     // Handle "stop" field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }


@@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat(
     }
 }
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -1670,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
@@ -2410,32 +2410,304 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node->data;
+    graph_node_properties->node_op = node->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+    }
+}
+
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+        node->op != GGML_OP_CPY &&
+        node->op != GGML_OP_VIEW) {
+        return false;
+    }
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_CPY &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+    return true;
+}
+
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
     ggml_cuda_set_device(cuda_ctx->device);
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor * node = cgraph->nodes[i];
-        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-            continue;
-        }
-#ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
-            }
-        }
-#endif
-        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-        if (!ok) {
-            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-        }
-        GGML_ASSERT(ok);
-    }
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+    // pointer to CUDA cpy kernel, which is required to identify
+    // kernel parameters which need updated in the graph for each token
+    void * ggml_cuda_cpy_fn_ptr = nullptr;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) {
+            cuda_graph_update_required = true;
+        }
+
+        // Check if the graph size has changed
+        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+            cuda_graph_update_required = true;
+            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if (!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            }
+            if (!has_matching_properties) {
+                cuda_graph_update_required = true;
+            }
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+
+        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+            if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+            }
+            if (node->op == GGML_OP_MUL_MAT_ID) {
+                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+            }
+            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+                // disable CUDA graphs for batch size > 1 for now.
+                // Changes in batch size or context size can cause changes to the grid size of some kernels.
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
+            if (node->op == GGML_OP_CPY) {
+                // store the copy op parameter which changes with each token.
+                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+                if (ggml_cuda_cpy_fn_ptr == nullptr) {
+                    // store a pointer to the copy op CUDA kernel to identify it later
+                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                }
+            }
+            if (!use_cuda_graph) {
+                break;
+            }
+        }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                    }
+                }
+#endif
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+        if (cuda_graph_update_required) {
+            // Extract nodes from graph
+            if (cuda_ctx->cuda_graph->num_nodes == 0) {
+                // First call with null argument gets number of nodes in graph
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+            }
+            // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+            if (cuda_ctx->cuda_graph->num_nodes > 0) {
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+                // Loop over nodes, and extract kernel parameters from each node
+                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                    cudaGraphNodeType node_type;
+                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                    if (node_type == cudaGraphNodeTypeKernel) {
+                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                        if (stat == cudaErrorInvalidDeviceFunction) {
+                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                            // We don't need to update blas nodes, so clear error and move on.
+                            cudaGetLastError();
+                        } else {
+                            GGML_ASSERT(stat == cudaSuccess);
+                        }
+                    }
+                }
+            }
+        }
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+            int k = 0;
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+                }
+            }
+        }
+
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
 
     return GGML_STATUS_SUCCESS;
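(Note, not part of the diff: the hunk above follows the standard CUDA graph lifecycle: record the per-token work with stream capture, instantiate an executable graph once, then on later tokens either update the executable in place or re-instantiate it, and replay it with a single launch. A minimal host-only C++ sketch of that cycle, requiring CUDA 12+ to match the USE_CUDA_GRAPH gate; the device buffers and the captured memset/memcpy are hypothetical stand-ins for the real kernels, and only the graph API calls mirror the code above:)

    // host-only sketch; compile with nvcc, or any C++ compiler linked against cudart
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define CHECK(call) do { cudaError_t e = (call); \
        if (e != cudaSuccess) { fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e)); exit(1); } } while (0)

    int main() {
        const size_t n = 1 << 20;
        char * a = nullptr;
        char * b = nullptr;
        CHECK(cudaMalloc((void **) &a, n));
        CHECK(cudaMalloc((void **) &b, n));

        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));

        cudaGraphExec_t instance = nullptr;

        for (int token = 0; token < 4; token++) {
            // 1. Capture the per-token work into a graph instead of running it eagerly.
            cudaGraph_t graph;
            CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
            CHECK(cudaMemsetAsync(a, token, n, stream));                          // stand-in for the real kernels
            CHECK(cudaMemcpyAsync(b, a, n, cudaMemcpyDeviceToDevice, stream));
            CHECK(cudaStreamEndCapture(stream, &graph));

            if (instance == nullptr) {
                // 2. First token: build the executable graph.
                CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
            } else {
                // 3. Later tokens: try an in-place exec update; on failure clear the error and
                //    re-instantiate, mirroring the cudaErrorGraphExecUpdateFailure path above.
                cudaGraphExecUpdateResultInfo info;
                if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
                    cudaGetLastError();
                    CHECK(cudaGraphExecDestroy(instance));
                    CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
                }
            }
            CHECK(cudaGraphDestroy(graph));

            // 4. Replay the whole recorded sequence with a single launch.
            CHECK(cudaGraphLaunch(instance, stream));
            CHECK(cudaStreamSynchronize(stream));
        }

        CHECK(cudaGraphExecDestroy(instance));
        CHECK(cudaStreamDestroy(stream));
        CHECK(cudaFree(a));
        CHECK(cudaFree(b));
        return 0;
    }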


@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
     clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }


@@ -19,6 +19,7 @@
 #include <cassert>
 #include <cfloat>
 #include <string>
+#include <vector>
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {


@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
 }
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            CUDA_CHECK(cudaGetDevice(&id));
-            if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;


@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     ggml_cuda_cpy(ctx, src0, dst);
 }
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ASSERT(false);
+    }
+}


@@ -5,3 +5,5 @@
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);


@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
 #if QK_K == 256
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
     int mmq_x, mmq_y, nwarps;
View file
@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_x % qk == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id; int id = ggml_cuda_get_device();
CUDA_CHECK(cudaGetDevice(&id));
int64_t nwarps = 1; int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1; int64_t rows_per_cuda_block = 1;
@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t ne0 = dst->ne[0]; const int64_t ne0 = dst->ne[0];
int id; int id = ggml_cuda_get_device();
CUDA_CHECK(cudaGetDevice(&id));
// the main device has a larger memory buffer to hold the results from all GPUs // the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into // nrows_dst == nrows of the matrix that the kernel writes into
View file
@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&scale, dst->op_params, sizeof(float)); memcpy(&scale, dst->op_params, sizeof(float));
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
CUDA_CHECK(cudaGetLastError());
} }
View file
@ -265,11 +265,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
static void * ggml_metal_host_malloc(size_t n) { static void * ggml_metal_host_malloc(size_t n) {
void * data = NULL; void * data = NULL;
#if TARGET_OS_OSX
kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
if (err != KERN_SUCCESS) {
GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
return NULL;
}
#else
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
if (result != 0) { if (result != 0) {
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
return NULL; return NULL;
} }
#endif
return data; return data;
} }
@ -2840,7 +2849,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_
ggml_backend_metal_free_device(); ggml_backend_metal_free_device();
if (ctx->owned) { if (ctx->owned) {
#if TARGET_OS_OSX
vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
#else
free(ctx->all_data); free(ctx->all_data);
#endif
} }
free(ctx); free(ctx);
@ -2944,14 +2957,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
ctx->owned = true; ctx->owned = true;
ctx->n_buffers = 1; ctx->n_buffers = 1;
ctx->buffers[0].data = ctx->all_data; if (ctx->all_data != NULL) {
ctx->buffers[0].size = size; ctx->buffers[0].data = ctx->all_data;
ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data ctx->buffers[0].size = size;
length:size_aligned ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
options:MTLResourceStorageModeShared length:size_aligned
deallocator:nil]; options:MTLResourceStorageModeShared
deallocator:nil];
}
if (ctx->buffers[0].metal == nil) { if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
free(ctx); free(ctx);
ggml_backend_metal_free_device(); ggml_backend_metal_free_device();
View file
@ -878,7 +878,7 @@ class GGUFValueType(IntEnum):
# Note: Does not support GGML_QKK_64 # Note: Does not support GGML_QKK_64
QK_K = 256 QK_K = 256
# Items here are (block size, type size) # Items here are (block size, type size)
GGML_QUANT_SIZES = { GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2), GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_0: (32, 2 + 16),
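The (block size, type size) pairs in this table are what gguf-py later uses to turn an element count into a byte count (see the n_bytes = n_elems * type_size // block_size line in the reader changes below). A minimal sketch of that calculation, not part of the diff; the dict literal and helper name here are illustrative only:

# Illustrative only: how (block size, type size) pairs translate into on-disk sizes.
# Q4_0 packs 32 elements into 2 + 16 = 18 bytes; F16 stores 1 element in 2 bytes.
SIZES = {
    "F32":  (1, 4),
    "F16":  (1, 2),
    "Q4_0": (32, 2 + 16),
}

def tensor_nbytes(n_elems: int, qtype: str) -> int:
    block_size, type_size = SIZES[qtype]
    return n_elems * type_size // block_size

# A 4096 x 4096 weight quantized to Q4_0:
# 16_777_216 elements * 18 bytes per block / 32 elements per block = 9_437_184 bytes (9 MiB).
print(tensor_nbytes(4096 * 4096, "Q4_0"))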
View file
@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
class GGUFReader: class GGUFReader:
# I - same as host, S - swapped # I - same as host, S - swapped
byte_order: Literal['I' | 'S'] = 'I' byte_order: Literal['I'] | Literal['S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT alignment: int = GGUF_DEFAULT_ALIGNMENT
# Note: Internal helper, API may change. # Note: Internal helper, API may change.
@ -83,7 +83,7 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_, GGUFValueType.BOOL: np.bool_,
} }
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'): def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
self.data = np.memmap(path, mode = mode) self.data = np.memmap(path, mode = mode)
offs = 0 offs = 0
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@ -128,7 +128,7 @@ class GGUFReader:
return self.tensors[idx] return self.tensors[idx]
def _get( def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None, self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
) -> npt.NDArray[Any]: ) -> npt.NDArray[Any]:
count = int(count) count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize) itemsize = int(np.empty([], dtype = dtype).itemsize)
@ -250,7 +250,7 @@ class GGUFReader:
raise ValueError(f'Found duplicated tensor with name {tensor_name}') raise ValueError(f'Found duplicated tensor with name {tensor_name}')
tensor_names.add(tensor_name) tensor_names.add(tensor_name)
ggml_type = GGMLQuantizationType(raw_dtype[0]) ggml_type = GGMLQuantizationType(raw_dtype[0])
n_elems = np.prod(dims) n_elems = int(np.prod(dims))
block_size, type_size = GGML_QUANT_SIZES[ggml_type] block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size n_bytes = n_elems * type_size // block_size
data_offs = int(start_offs + offset_tensor[0]) data_offs = int(start_offs + offset_tensor[0])
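Two of the reader changes are purely about satisfying a type checker (the new pyrightconfig.json appears later in this diff): Literal must be parameterized with literal values rather than an expression like 'I' | 'S', so the union is rewritten as a union of Literal types, and np.prod returns a NumPy scalar that is wrapped in int() to get a plain Python int. A small sketch of both patterns, with illustrative names only:

from __future__ import annotations  # keep the annotations unevaluated on Python < 3.10

from typing import Literal
import numpy as np

class Example:
    # valid: a union of Literal types (equivalently Literal['I', 'S'])
    byte_order: Literal['I'] | Literal['S'] = 'I'
    # rejected by type checkers: Literal['I' | 'S'] puts an expression,
    # not literal values, inside the subscript

n_elems = int(np.prod([32, 4096]))  # np.prod returns a NumPy integer scalar; int() gives a plain int
print(n_elems)                      # 131072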
View file
@ -7,7 +7,7 @@ import struct
import tempfile import tempfile
from enum import Enum, auto from enum import Enum, auto
from io import BufferedWriter from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping from typing import IO, Any, Callable, Sequence, Mapping
from string import ascii_letters, digits from string import ascii_letters, digits
import numpy as np import numpy as np
@ -28,6 +28,47 @@ from .constants import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LazyTensor:
data: Callable[[], np.ndarray[Any, Any]]
# to avoid too deep recursion
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
dtype: np.dtype[Any]
shape: tuple[int, ...]
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
self.data = data
self.functions = []
self.dtype = np.dtype(dtype)
self.shape = shape
def astype(self, dtype: type, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.astype(dtype, **kwargs))
self.dtype = np.dtype(dtype)
return self
@property
def nbytes(self) -> int:
size = 1
for n in self.shape:
size *= n
return size * self.dtype.itemsize
def tofile(self, *args, **kwargs) -> None:
data = self.data()
for f in self.functions:
data = f(data)
assert data.shape == self.shape
assert data.dtype == self.dtype
assert data.nbytes == self.nbytes
self.functions = []
self.data = lambda: data
data.tofile(*args, **kwargs)
def byteswap(self, *args, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
return self
class WriterState(Enum): class WriterState(Enum):
EMPTY = auto() EMPTY = auto()
HEADER = auto() HEADER = auto()
@ -38,7 +79,7 @@ class WriterState(Enum):
class GGUFWriter: class GGUFWriter:
fout: BufferedWriter fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[np.ndarray[Any, Any]] tensors: list[np.ndarray[Any, Any] | LazyTensor]
_simple_value_packing = { _simple_value_packing = {
GGUFValueType.UINT8: "B", GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b", GGUFValueType.INT8: "b",
@ -176,7 +217,7 @@ class GGUFWriter:
if pack_fmt is not None: if pack_fmt is not None:
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
elif vtype == GGUFValueType.STRING: elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val encoded_val = val.encode("utf-8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val)) self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@ -205,7 +246,7 @@ class GGUFWriter:
raise ValueError(f'Duplicated tensor name {name}') raise ValueError(f'Duplicated tensor name {name}')
self.ti_names.add(name) self.ti_names.add(name)
encoded_name = name.encode("utf8") encoded_name = name.encode("utf-8")
self.ti_data += self._pack("Q", len(encoded_name)) self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name self.ti_data += encoded_name
n_dims = len(tensor_shape) n_dims = len(tensor_shape)
@ -237,7 +278,7 @@ class GGUFWriter:
self.ti_data_count += 1 self.ti_data_count += 1
def add_tensor( def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None, raw_dtype: GGMLQuantizationType | None = None,
) -> None: ) -> None:
if self.endianess == GGUFEndian.BIG: if self.endianess == GGUFEndian.BIG:
@ -262,7 +303,7 @@ class GGUFWriter:
if pad != 0: if pad != 0:
fp.write(bytes([0] * pad)) fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
if self.state is not WriterState.TI_DATA: if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}') raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@ -272,15 +313,33 @@ class GGUFWriter:
tensor.tofile(self.fout) tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes) self.write_padding(self.fout, tensor.nbytes)
def write_tensors_to_file(self) -> None: def write_tensors_to_file(self, *, progress: bool = False) -> None:
self.write_ti_data_to_file() self.write_ti_data_to_file()
self.write_padding(self.fout, self.fout.tell()) self.write_padding(self.fout, self.fout.tell())
if self.temp_file is None: if self.temp_file is None:
self.tensors.reverse() # to pop from the "beginning" in constant time
if progress:
from tqdm import tqdm
total_bytes = sum(t.nbytes for t in self.tensors)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
while True:
try:
tensor = self.tensors.pop()
except IndexError:
break
tensor.tofile(self.fout)
bar.update(tensor.nbytes)
self.write_padding(self.fout, tensor.nbytes)
return
while True: while True:
try: try:
tensor = self.tensors.pop(0) tensor = self.tensors.pop()
except IndexError: except IndexError:
break break
tensor.tofile(self.fout) tensor.tofile(self.fout)
@ -479,7 +538,7 @@ class GGUFWriter:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if isinstance(value, list): if not isinstance(value, str):
template_default = None template_default = None
template_names = set() template_names = set()
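The LazyTensor added above defers reading tensor data until tofile() is called: the constructor records only a loader callback plus dtype and shape, astype() and byteswap() queue conversions instead of applying them, and nbytes is computed from the recorded metadata so GGUFWriter can lay out the file without touching the data. write_tensors_to_file() also reverses the tensor list and pops from the end, which keeps FIFO order while avoiding the O(n) cost of pop(0), and can optionally report progress through tqdm. A small usage sketch, assuming the gguf package from this tree is importable; the loader and file name are made up:

import numpy as np
from gguf.gguf_writer import LazyTensor  # as defined in the diff above

def load_weight() -> np.ndarray:
    # stands in for an expensive read, e.g. slicing a memory-mapped checkpoint
    return np.ones((4096, 32), dtype=np.float32)

lazy = LazyTensor(load_weight, dtype=np.float32, shape=(4096, 32))
lazy.astype(np.float16)       # queued; nothing is loaded or converted yet

print(lazy.nbytes)            # 4096 * 32 * 2 = 262144, known without loading
with open("tensor.bin", "wb") as f:
    lazy.tofile(f)            # the data is materialized and converted only here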
View file
@ -4,7 +4,7 @@ import logging
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable, Sequence, Mapping, Iterable
from .gguf_writer import GGUFWriter from .gguf_writer import GGUFWriter
@ -15,11 +15,11 @@ class SpecialVocab:
merges: list[str] merges: list[str]
add_special_token: dict[str, bool] add_special_token: dict[str, bool]
special_token_ids: dict[str, int] special_token_ids: dict[str, int]
chat_template: str | None chat_template: str | Sequence[Mapping[str, str]] | None
def __init__( def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False, self, path: str | os.PathLike[str], load_merges: bool = False,
special_token_types: tuple[str, ...] | None = None, special_token_types: Iterable[str] | None = None,
n_vocab: int | None = None, n_vocab: int | None = None,
): ):
self.special_token_ids = {} self.special_token_ids = {}
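SpecialVocab.chat_template is widened here because a Hugging Face tokenizer_config.json can carry either a single Jinja string or a list of named templates; the add_chat_template change above (if not isinstance(value, str)) accepts the same two shapes. A sketch of the two forms, with made-up template bodies:

# Illustrative values only: the two shapes chat_template may take.
single_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

named_templates = [
    {"name": "default",  "template": single_template},
    {"name": "tool_use", "template": single_template},
]

# Either value can be passed to GGUFWriter.add_chat_template(); the list form is
# what the template_default / template_names handling above is preparing for.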
View file
@ -21,6 +21,7 @@ classifiers = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8" python = ">=3.8"
numpy = ">=1.17" numpy = ">=1.17"
tqdm = ">=4.27"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^5.2" pytest = "^5.2"
View file
@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
if len(field.types) == 1: if len(field.types) == 1:
curr_type = field.types[0] curr_type = field.types[0]
if curr_type == GGUFValueType.STRING: if curr_type == GGUFValueType.STRING:
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])) log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
elif field.types[0] in reader.gguf_scalar_to_np: elif field.types[0] in reader.gguf_scalar_to_np:
log_message += ' = {0}'.format(field.parts[-1][0]) log_message += ' = {0}'.format(field.parts[-1][0])
View file

@ -7,7 +7,7 @@ import json
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from typing import Any, Mapping, Sequence from typing import Any, Sequence
# Necessary to load the local gguf package # Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
return host_endian return host_endian
def decode_field(field: gguf.ReaderField) -> Any: def decode_field(field: gguf.ReaderField | None) -> Any:
if field and field.types: if field and field.types:
main_type = field.types[0] main_type = field.types[0]
@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
sub_type = field.types[-1] sub_type = field.types[-1]
if sub_type == gguf.GGUFValueType.STRING: if sub_type == gguf.GGUFValueType.STRING:
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data] return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
else: else:
return [pv for idx in field.data for pv in field.parts[idx].tolist()] return [pv for idx in field.data for pv in field.parts[idx].tolist()]
if main_type == gguf.GGUFValueType.STRING: if main_type == gguf.GGUFValueType.STRING:
return str(bytes(field.parts[-1]), encoding='utf8') return str(bytes(field.parts[-1]), encoding='utf-8')
else: else:
return field.parts[-1][0] return field.parts[-1][0]
@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
return decode_field(field) return decode_field(field)
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None: def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
for field in reader.fields.values(): for field in reader.fields.values():
# Suppress virtual fields and fields written by GGUFWriter # Suppress virtual fields and fields written by GGUFWriter
if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
for tensor in reader.tensors: for tensor in reader.tensors:
# Dimensions are written in reverse order, so flip them first # Dimensions are written in reverse order, so flip them first
shape = np.flipud(tensor.shape) shape = np.flipud(tensor.shape).tolist()
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type) writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
writer.write_header_to_file() writer.write_header_to_file()
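The shape handed to add_tensor_info is now a plain Python list: GGUF writes dimensions in reverse order (hence the flip), and .tolist() turns the NumPy integers into ints the writer can pack directly. A one-line sketch with illustrative values:

import numpy as np
shape = np.array([32, 4096], dtype=np.uint64)  # dimensions as read back from a GGUF file
print(np.flipud(shape).tolist())               # [4096, 32] as plain Python ints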
pyrightconfig.json (new file, 3 additions)
View file
@ -0,0 +1,3 @@
{
"extraPaths": ["gguf-py"],
}
View file
@ -1,3 +1,2 @@
-r ./requirements-convert.txt -r ./requirements-convert.txt
torch~=2.1.1 torch~=2.1.1
einops~=0.7.0
View file
@ -1,3 +1,2 @@
-r ./requirements-convert.txt -r ./requirements-convert.txt
torch~=2.1.1 torch~=2.1.1
View file

@ -1,5 +1,5 @@
numpy~=1.24.4 numpy~=1.24.4
sentencepiece~=0.1.98 sentencepiece~=0.2.0
transformers>=4.40.1,<5.0.0 transformers>=4.40.1,<5.0.0
gguf>=0.1.0 gguf>=0.1.0
protobuf>=4.21.0,<5.0.0 protobuf>=4.21.0,<5.0.0
View file
@ -2,6 +2,7 @@
#undef NDEBUG #undef NDEBUG
#endif #endif
#include <cassert>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <regex> #include <regex>