Merge remote-tracking branch 'origin/master' into json-fixes

ochafik 2024-03-16 00:45:07 +00:00
commit 5602a8b649
23 changed files with 1011 additions and 145 deletions

View file

@@ -554,7 +554,7 @@ endif
endif # LLAMA_METAL
ifdef LLAMA_METAL
-ggml-metal.o: ggml-metal.m ggml-metal.h
+ggml-metal.o: ggml-metal.m ggml-metal.h ggml.h
	$(CC) $(CFLAGS) -c $< -o $@
ifdef LLAMA_METAL_EMBED_LIBRARY

View file

@@ -112,6 +112,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)
+- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)

**Multimodal models:**

View file

@@ -568,6 +568,34 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.lora_base = argv[i];
} else if (arg == "--control-vector") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back({ 1.0f, argv[i], });
} else if (arg == "--control-vector-scaled") {
if (++i >= argc) {
invalid_param = true;
break;
}
const char * fname = argv[i];
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back({ std::stof(argv[i]), fname, });
} else if (arg == "--control-vector-layer-range") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vector_layer_start = std::stoi(argv[i]);
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vector_layer_end = std::stoi(argv[i]);
} else if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
@@ -1095,6 +1123,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --control-vector FNAME\n");
printf(" add a control vector\n");
printf(" --control-vector-scaled FNAME S\n");
printf(" add a control vector with user defined scaling S\n");
printf(" --control-vector-layer-range START END\n");
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n"); printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str()); printf(" model path (default: %s)\n", params.model.c_str());
printf(" -md FNAME, --model-draft FNAME\n"); printf(" -md FNAME, --model-draft FNAME\n");
@@ -1360,6 +1394,30 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
return std::make_tuple(nullptr, nullptr);
}
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
const auto cvec = llama_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
int err = llama_control_vector_apply(lctx,
cvec.data.data(),
cvec.data.size(),
cvec.n_embd,
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
}
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
@@ -1890,3 +1948,160 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) {
return sum / (sqrt(sum1) * sqrt(sum2));
}
//
// Control vector utils
//
static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
int32_t n_tensors;
size_t n_bytes = 0;
uint32_t max_direction_layer = 0;
llama_control_vector_data result = { -1, {} };
// calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer
{
struct ggml_init_params meta_params = {
/* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ggml_context * meta_ctx = ggml_init(meta_params);
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true,
/* .ctx = */ &meta_ctx,
};
struct gguf_context * meta_ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!meta_ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
return result;
}
n_tensors = gguf_get_n_tensors(meta_ctx_gguf);
for (int i = 0; i < n_tensors; i++) {
std::string name = gguf_get_tensor_name(meta_ctx_gguf, i);
// split on '.'
size_t dotpos = name.find('.');
if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
try {
uint32_t layer = std::stoi(name.substr(dotpos + 1));
if (layer == 0) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
if (layer > max_direction_layer) {
max_direction_layer = layer;
}
} catch (...) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
}
struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str());
if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor_meta);
} else if (ggml_nelements(tensor_meta) != result.n_embd) {
fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
n_bytes += ggml_nbytes(tensor_meta);
}
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
}
if (n_tensors == 0) {
fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
return result;
}
// load and scale tensors into final control vector context
struct ggml_init_params ggml_params = {
/* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes,
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(ggml_params);
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), params);
if (!ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
ggml_free(ctx);
return result;
}
// do not store data for layer 0 (it's not used)
result.data.resize(result.n_embd * max_direction_layer);
for (uint32_t il = 1; il <= max_direction_layer; il++) {
const std::string name = "direction." + std::to_string(il);
const ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
float * dst = result.data.data() + result.n_embd * (il - 1);
if (tensor) {
const float * src = (const float *) tensor->data;
for (int j = 0; j < result.n_embd; j++) {
dst[j] = src[j] * load_info.strength;
}
} else {
for (int j = 0; j < result.n_embd; j++) {
dst[j] = 0.0f;
}
}
}
return result;
}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
llama_control_vector_data result = { -1, {} };
for (const auto & info : load_infos) {
auto cur = llama_control_vector_load_one(info);
if (cur.n_embd == -1) {
return result;
}
if (result.n_embd != -1 && (result.n_embd != cur.n_embd || result.data.size() != cur.data.size())) {
fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str());
return result;
}
if (result.n_embd == -1) {
result = std::move(cur);
} else {
for (size_t i = 0; i < cur.data.size(); i++) {
result.data[i] += cur.data[i];
}
}
}
if (result.n_embd == -1) {
fprintf(stderr, "%s: no vectors passed\n", __func__);
}
return result;
}
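The loader above expects a GGUF file holding one-dimensional F32 tensors named direction.1 through direction.N, one per layer, all of the same length (n_embd). As a rough illustration only, a file of that shape could be produced with the existing GGUF writer API in ggml; the n_embd/n_layer parameters and the zero fill below are placeholders, not part of this commit:

// Sketch: emit a GGUF file that llama_control_vector_load() above can parse.
#include "ggml.h"
#include <string>

static void write_control_vector_gguf(const char * fname, int n_embd, int n_layer) {
    struct ggml_init_params params = {
        /* .mem_size   = */ ggml_tensor_overhead() * (size_t) n_layer
                            + (size_t) n_layer * n_embd * sizeof(float),
        /* .mem_buffer = */ nullptr,
        /* .no_alloc   = */ false,
    };
    struct ggml_context * ctx      = ggml_init(params);
    struct gguf_context * ctx_gguf = gguf_init_empty();

    // the loader treats layer 0 as invalid, so directions start at 1
    for (int il = 1; il <= n_layer; il++) {
        struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
        ggml_set_name(t, ("direction." + std::to_string(il)).c_str());
        ggml_set_f32(t, 0.0f); // placeholder values
        gguf_add_tensor(ctx_gguf, t);
    }

    gguf_write_to_file(ctx_gguf, fname, /* only_meta = */ false);
    gguf_free(ctx_gguf);
    ggml_free(ctx);
}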

View file

@@ -37,10 +37,13 @@ extern char const *LLAMA_COMMIT;
extern char const *LLAMA_COMPILER;
extern char const *LLAMA_BUILD_TARGET;

+struct llama_control_vector_load_info;
+
+int32_t get_num_physical_cores();
+
//
// CLI argument parsing
//
-int32_t get_num_physical_cores();

struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
@@ -103,6 +106,11 @@ struct gpt_params {
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
// (which is more convenient to use for plotting)
@@ -269,3 +277,24 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
void llama_embd_normalize(const float * inp, float * out, int n);
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
//
// Control vector utils
//
struct llama_control_vector_data {
int n_embd;
// stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
std::vector<float> data;
};
struct llama_control_vector_load_info {
float strength;
std::string fname;
};
// Load control vectors, scale each by strength, and add them together.
// On error, returns {-1, empty}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
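Together with llama_control_vector_apply (declared in llama.h), typical usage mirrors the llama_init_from_gpt_params code above. A minimal sketch, assuming an already initialized model/context pair (model, lctx); the file names and strengths are hypothetical:

// Sketch: combine two scaled control vectors and apply them to a context.
std::vector<llama_control_vector_load_info> infos = {
    {  0.8f, "happy.gguf" }, // hypothetical file, scaled by 0.8
    { -0.4f, "angry.gguf" }, // negative strength steers away from a direction
};
llama_control_vector_data cvec = llama_control_vector_load(infos);
if (cvec.n_embd != -1) {
    int err = llama_control_vector_apply(lctx,
            cvec.data.data(), cvec.data.size(), cvec.n_embd,
            /* il_start = */ 1, /* il_end = */ llama_n_layer(model));
    // non-zero err means the vector's n_embd did not match the model
}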

View file

@@ -1965,6 +1965,23 @@ class MambaModel(Model):
        self.gguf_writer.add_tensor(new_name, data)
+@Model.register("CohereForCausalLM")
+class CommandR2Model(Model):
+    model_arch = gguf.MODEL_ARCH.COMMAND_R
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # max_position_embeddings = 8192 in config.json but model was actually
+        # trained on 128k context length
+        self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
###### CONVERSION LOGIC ######

View file

@@ -8,6 +8,7 @@
#include <cstdio>
#include <cstring>
#include <ctime>
+#include <cstdlib>
#include <iterator>
#include <map>
#include <numeric>
@@ -1123,15 +1124,19 @@ struct sql_printer : public printer {
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

-//std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
-//llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
-//GGML_UNUSED(n_batch);
-std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
+const llama_model * model = llama_get_model(ctx);
+const int32_t n_vocab = llama_n_vocab(model);
+std::vector<llama_token> tokens(n_batch);

int n_processed = 0;

while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
+tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
+for (int i = 1; i < n_tokens; i++) {
+tokens[i] = std::rand() % n_vocab;
+}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
n_processed += n_tokens;
}
@@ -1142,11 +1147,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

-llama_token token = llama_token_bos(llama_get_model(ctx));
+const llama_model * model = llama_get_model(ctx);
+const int32_t n_vocab = llama_n_vocab(model);
+llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;

for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
llama_synchronize(ctx);
+token = std::rand() % n_vocab;
}
}

View file

@@ -1235,16 +1235,16 @@ struct clip_image_f32 * clip_image_f32_init() {
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }

-void clip_image_u8_batch_free(struct clip_image_u8_batch & batch) {
-if (batch.size > 0) {
-delete[] batch.data;
-batch.size = 0;
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {
+if (batch->size > 0) {
+delete[] batch->data;
+batch->size = 0;
}
}

-void clip_image_f32_batch_free(struct clip_image_f32_batch & batch) {
-if (batch.size > 0) {
-delete[] batch.data;
-batch.size = 0;
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
+if (batch->size > 0) {
+delete[] batch->data;
+batch->size = 0;
}
}
@@ -1497,7 +1497,7 @@ static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
-bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs) {
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
bool pad_to_square = true;
if (!ctx->has_vision_encoder) {
printf("This gguf file seems to have no vision encoder\n");
@@ -1509,11 +1509,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs) {
pad_to_square = false;
}
// free the previous res_imgs if any set
-if (res_imgs.size > 0) {
+if (res_imgs->size > 0) {
clip_image_f32_batch_free(res_imgs);
}
-res_imgs.data = nullptr;
-res_imgs.size = 0;
+res_imgs->data = nullptr;
+res_imgs->size = 0;

// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
@@ -1568,11 +1568,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs) {
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
patches.insert(patches.begin(), image_original_resize);
// clip_image_f32_batch_init(patches.size());
-res_imgs.size = patches.size();
-res_imgs.data = new clip_image_f32[res_imgs.size];
+res_imgs->size = patches.size();
+res_imgs->data = new clip_image_f32[res_imgs->size];
int num=0;
for (auto& patch : patches) {
-normalize_image_u8_to_f32(patch, &res_imgs.data[num], ctx->image_mean, ctx->image_std);
+normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
num++;
}
@@ -1660,9 +1660,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs) {
// }
// res_imgs.push_back(res);

-res_imgs.size = 1;
-res_imgs.data = new clip_image_f32[res_imgs.size];
-res_imgs.data[0] = *res;
+res_imgs->size = 1;
+res_imgs->data = new clip_image_f32[res_imgs->size];
+res_imgs->data[0] = *res;
clip_image_f32_free(res);

return true;

View file

@@ -60,8 +60,8 @@ CLIP_API struct clip_image_f32 * clip_image_f32_init();
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
-CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch & batch);
-CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch & batch);
+CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
+CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);

CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);

@@ -69,7 +69,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);

/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */
-CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs );
+CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );

CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);

View file

@@ -223,7 +223,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
clip_image_f32_batch img_res_v;
img_res_v.size = 0;
img_res_v.data = nullptr;
-if (!clip_image_preprocess(ctx_clip, img, img_res_v)) {
+if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) {
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
delete[] img_res_v.data;
return false;

View file

@@ -29,9 +29,9 @@ struct llava_image_embed {
};

/** sanity check for clip <-> llava embed size match */
-LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
+LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip);

-LLAVA_API bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);
+LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out);

/** build an image embed from image file bytes */
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);

View file

@@ -13,8 +13,11 @@ source /opt/intel/oneapi/setvars.sh
#for FP32
cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx

-#build example/main only
+#build example/main
#cmake --build . --config Release --target main

+#build example/llama-bench
+#cmake --build . --config Release --target llama-bench

#build all binary
cmake --build . --config Release -v

View file

@@ -9,18 +9,28 @@ source /opt/intel/oneapi/setvars.sh
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
+GGML_SYCL_SINGLE_GPU=1
else
GGML_SYCL_DEVICE=0
+GGML_SYCL_SINGLE_GPU=0 # default to multi-GPU mode when no device id is given
fi
-echo "use $GGML_SYCL_DEVICE as main GPU"

#export GGML_SYCL_DEBUG=1

#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.

-#use all GPUs with same max compute units
-ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0
+if [ $GGML_SYCL_SINGLE_GPU -eq 1 ]; then
+echo "use $GGML_SYCL_DEVICE as main GPU"
+#use single GPU only
+ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none
+else
+#use multiple GPUs with same max compute units
+ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0
+fi

#use main GPU only
#ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none
+#use multiple GPUs with same max compute units
+#ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0

View file

@@ -11541,6 +11541,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
if (ggml_backend_is_cuda(event->backend)) {
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[cuda_ctx->device][0], (cudaEvent_t)event->context, 0));
} else {
+#if 0
// untested
auto wait_fn = [](void * user_data) {
ggml_backend_event_t event = (ggml_backend_event_t)user_data;

@@ -11548,6 +11549,8 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
};
CUDA_CHECK(cudaLaunchHostFunc(g_cudaStreams[cuda_ctx->device][0], wait_fn, event));
+#endif
+GGML_ASSERT(false);
}
}

View file

@@ -16,6 +16,7 @@
#include <cinttypes>
#include <cstddef>
#include <cstdint>
+#include <cstdlib>
#include <float.h>
#include <limits>
#include <stdint.h>

@@ -24,10 +25,9 @@
#include <cmath>
#include <iostream>
#include <fstream>
#include <stdio.h>
#include <stdlib.h>
+#include <regex>

#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
@@ -82,6 +82,30 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cpp
#define __dpct_noinline__ __attribute__((noinline))
#endif
std::string get_device_type_name(const sycl::device &Device) {
auto DeviceType = Device.get_info<sycl::info::device::device_type>();
switch (DeviceType) {
case sycl::info::device_type::cpu:
return "cpu";
case sycl::info::device_type::gpu:
return "gpu";
case sycl::info::device_type::host:
return "host";
case sycl::info::device_type::accelerator:
return "acc";
default:
return "unknown";
}
}
std::string get_device_backend_and_type(const sycl::device &device) {
std::stringstream device_type;
sycl::backend backend = device.get_backend();
device_type << backend << ":" << get_device_type_name(device);
return device_type.str();
}
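As a standalone illustration (not part of the diff), the helper pair above can be exercised to enumerate every SYCL device with the backend:type tag the new device manager sorts on:

// Sketch: print "backend:type" for all SYCL devices, e.g. "ext_oneapi_level_zero:gpu".
#include <iostream>
#include <sycl/sycl.hpp>

int main() {
    for (const auto & dev : sycl::device::get_devices()) {
        std::cout << get_device_backend_and_type(dev) << " : "
                  << dev.get_info<sycl::info::device::name>() << std::endl;
    }
    return 0;
}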
namespace dpct
{
typedef sycl::queue *queue_ptr;

@@ -942,17 +966,65 @@ namespace dpct
private:
mutable std::recursive_mutex m_mutex;
static bool compare_dev(sycl::device &device1, sycl::device &device2)
{
dpct::device_info prop1;
dpct::get_device_info(prop1, device1);
dpct::device_info prop2;
dpct::get_device_info(prop2, device2);
return prop1.get_max_compute_units() > prop2.get_max_compute_units();
}
static int convert_backend_index(std::string & backend) {
if (backend == "ext_oneapi_level_zero:gpu") return 0;
if (backend == "opencl:gpu") return 1;
if (backend == "opencl:cpu") return 2;
if (backend == "opencl:acc") return 3;
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
GGML_ASSERT(false);
}
static bool compare_backend(std::string &backend1, std::string &backend2) {
return convert_backend_index(backend1) < convert_backend_index(backend2);
}
dev_mgr()
{
sycl::device default_device =
sycl::device(sycl::default_selector_v);
_devs.push_back(std::make_shared<device_ext>(default_device));

-std::vector<sycl::device> sycl_all_devs =
-sycl::device::get_devices(sycl::info::device_type::all);
+std::vector<sycl::device> sycl_all_devs;
// Collect other devices except for the default device.
if (default_device.is_cpu())
_cpu_device = 0;

+auto Platforms = sycl::platform::get_platforms();
+// Keep track of the number of devices per backend
+std::map<sycl::backend, size_t> DeviceNums;
+std::map<std::string, std::vector<sycl::device>> backend_devices;
+
+while (!Platforms.empty()) {
+auto Platform = Platforms.back();
+Platforms.pop_back();
+auto devices = Platform.get_devices();
+std::string backend_type = get_device_backend_and_type(devices[0]);
+for (const auto &device : devices) {
+backend_devices[backend_type].push_back(device);
+}
+}
+
+std::vector<std::string> keys;
+for (auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
+keys.push_back(it->first);
+}
+std::sort(keys.begin(), keys.end(), compare_backend);
+
+for (auto &key : keys) {
+std::vector<sycl::device> devs = backend_devices[key];
+std::sort(devs.begin(), devs.end(), compare_dev);
+for (const auto &dev : devs) {
+sycl_all_devs.push_back(dev);
+}
+}
+
for (auto &dev : sycl_all_devs)
{
if (dev == default_device)
@@ -3202,6 +3274,11 @@ static int g_work_group_size = 0;
#define GGML_SYCL_MMV_Y 1
#endif
enum ggml_sycl_backend_gpu_mode {
SYCL_UNSET_GPU_MODE = -1,
SYCL_SINGLE_GPU_MODE = 0,
SYCL_MUL_GPU_MODE
};
static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@@ -3401,12 +3478,31 @@ class sycl_gpu_mgr {
int work_group_size = 0;
std::string gpus_list = "";
/*
Use all GPUs with same top max compute units
*/
sycl_gpu_mgr() {
detect_sycl_gpu_list_with_max_cu();
get_allow_gpus();
create_context_with_gpus();
}
/*
Only use the assigned GPU
*/
sycl_gpu_mgr(int main_gpu_id) {
sycl::device device = dpct::dev_mgr::instance().get_device(main_gpu_id);
dpct::device_info prop;
dpct::get_device_info(prop, device);
gpus.push_back(main_gpu_id);
devices.push_back(device);
work_group_size = prop.get_max_work_group_size();
max_compute_units = prop.get_max_compute_units();
get_allow_gpus();
create_context_with_gpus();
}
void create_context_with_gpus() {
sycl::context ctx = sycl::context(devices);
assert(gpus.size() > 0);
@@ -3422,7 +3518,7 @@ class sycl_gpu_mgr {
gpus_list += std::to_string(gpus[i]);
gpus_list += ",";
}
-if (gpus_list.length() > 2) {
+if (gpus_list.length() > 1) {
gpus_list.pop_back();
}
}
@@ -3451,7 +3547,7 @@ class sycl_gpu_mgr {
dpct::device_info prop;
dpct::get_device_info(prop, device);
if (max_compute_units == prop.get_max_compute_units() &&
-prop.get_major_version() == 1) {
+is_ext_oneapi_device(device)) {
gpus.push_back(id);
devices.push_back(device);
work_group_size = prop.get_max_work_group_size();
@@ -3471,8 +3567,8 @@ class sycl_gpu_mgr {
if (gpus[i] == id)
return i;
}
-assert(false);
-return -1;
+printf("miss to get device index by id=%d\n", id);
+GGML_ASSERT(false);
}
int get_next_index(int id) {

@@ -3481,8 +3577,16 @@ class sycl_gpu_mgr {
if (gpus[i] == id)
return i;
}
-assert(false);
-return -1;
+GGML_ASSERT(false);
}
+
+bool is_ext_oneapi_device(const sycl::device &dev) {
+sycl::backend dev_backend = dev.get_backend();
+if (dev_backend == sycl::backend::ext_oneapi_level_zero ||
+dev_backend == sycl::backend::ext_oneapi_cuda ||
+dev_backend == sycl::backend::ext_oneapi_hip)
+return true;
+return false;
+}
};
@@ -3491,11 +3595,14 @@ static int g_device_count = -1;
static int g_all_sycl_device_count = -1;
static int g_main_device = -1;
static int g_main_device_id = -1;
+static bool g_ggml_backend_sycl_buffer_type_initialized = false;

static std::array<float, GGML_SYCL_MAX_DEVICES> g_default_tensor_split = {};
static float g_tensor_split[GGML_SYCL_MAX_DEVICES] = {0};
+static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode = SYCL_UNSET_GPU_MODE;
struct sycl_device_capabilities {
int cc; // compute capability
bool vmm; // virtual memory support
@@ -12999,17 +13106,20 @@ bool ggml_sycl_loaded(void) {
return g_sycl_loaded;
}

-void print_device_detail(int id) {
+void print_device_detail(int id, sycl::device &device, std::string device_type) {
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(
-dpct::get_device_info(prop, dpct::dev_mgr::instance().get_device(id))));
-sycl::device cur_device = dpct::dev_mgr::instance().get_device(id);
+dpct::get_device_info(prop, device)));

std::string version;
version += std::to_string(prop.get_major_version());
version += ".";
version += std::to_string(prop.get_minor_version());

+device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");

-fprintf(stderr, "|%2d|%45s|%18s|%17d|%14d|%13d|%15lu|\n", id,
+fprintf(stderr, "|%2d|%18s|%45s|%10s|%11d|%8d|%7d|%15lu|\n", id, device_type.c_str(),
prop.get_name(), version.c_str(), prop.get_max_compute_units(),
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
prop.get_global_mem_size());
@@ -13017,19 +13127,35 @@ void print_device_detail(int id) {
void ggml_backend_sycl_print_sycl_devices() {
int device_count = dpct::dev_mgr::instance().device_count();
+std::map<std::string, size_t> DeviceNums;
fprintf(stderr, "found %d SYCL devices:\n", device_count);
-fprintf(stderr, "|ID| Name |compute capability|Max compute units|Max work group|Max sub group|Global mem size|\n");
-fprintf(stderr, "|--|---------------------------------------------|------------------|-----------------|--------------|-------------|---------------|\n");
+fprintf(stderr, "| | | |Compute |Max compute|Max work|Max sub| |\n");
+fprintf(stderr, "|ID| Device Type| Name|capability|units |group |group |Global mem size|\n");
+fprintf(stderr, "|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|\n");
for (int id = 0; id < device_count; ++id) {
-print_device_detail(id);
+sycl::device device = dpct::dev_mgr::instance().get_device(id);
+sycl::backend backend = device.get_backend();
+std::string backend_type = get_device_backend_and_type(device);
+int type_id = DeviceNums[backend_type]++;
+std::stringstream device_type;
+device_type << "[" << backend_type << ":" << std::to_string(type_id) << "]";
+print_device_detail(id, device, device_type.str());
}
}
void print_gpu_device_list() {
-fprintf(stderr, "detect %d SYCL GPUs: [%s] with Max compute units:%d\n",
-g_sycl_gpu_mgr->get_gpu_count(),
-g_sycl_gpu_mgr->gpus_list.c_str(),
-g_sycl_gpu_mgr->max_compute_units);
+GGML_ASSERT(g_sycl_gpu_mgr);
+
+char* hint = NULL;
+if (g_ggml_sycl_backend_gpu_mode == SYCL_SINGLE_GPU_MODE) {
+hint = "use %d SYCL GPUs: [%s] with Max compute units:%d\n";
+} else {
+hint = "detect %d SYCL GPUs: [%s] with top Max compute units:%d\n";
+}
+fprintf(stderr, hint,
+g_sycl_gpu_mgr->get_gpu_count(),
+g_sycl_gpu_mgr->gpus_list.c_str(),
+g_sycl_gpu_mgr->max_compute_units);
}
int get_sycl_env(const char *env_name, int default_val) {
@@ -13065,23 +13191,6 @@ void ggml_init_sycl() try {
#else
fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
#endif
-if (CHECK_TRY_ERROR(g_all_sycl_device_count =
-dpct::dev_mgr::instance().device_count()) != 0) {
-initialized = true;
-g_sycl_loaded = false;
-return;
-}
-GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
-ggml_backend_sycl_print_sycl_devices();
-
-if (!g_sycl_gpu_mgr) g_sycl_gpu_mgr = new sycl_gpu_mgr();
-g_device_count = g_sycl_gpu_mgr->get_gpu_count();
-g_work_group_size = g_sycl_gpu_mgr->work_group_size;
-
-print_gpu_device_list();
-
-int64_t total_vram = 0;
/* NOT REMOVE, keep it for next optimize for XMX.
#if defined(SYCL_USE_XMX)

@@ -13090,49 +13199,15 @@ void ggml_init_sycl() try {
fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
#endif
*/
-for (int id = 0; id < GGML_SYCL_MAX_DEVICES; ++id) {
-g_device_caps[id].vmm = 0;
-g_device_caps[id].device_id = -1;
-g_device_caps[id].cc = 0;
-g_tensor_split[id] = 0;
-g_default_tensor_split[id] = 0;
-}
-
-for (int i = 0; i < g_device_count; ++i) {
-int device_id = g_sycl_gpu_mgr->gpus[i];
-g_device_caps[i].vmm = 0;
-dpct::device_info prop;
-SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-prop, dpct::dev_mgr::instance().get_device(device_id))));
-g_default_tensor_split[i] = total_vram;
-total_vram += prop.get_global_mem_size();
-g_device_caps[i].cc =
-100 * prop.get_major_version() + 10 * prop.get_minor_version();
-}
-
-for (int i = 0; i < g_device_count; ++i) {
-g_default_tensor_split[i] /= total_vram;
-}
-
-for (int i = 0; i < g_device_count; ++i) {
-SYCL_CHECK(ggml_sycl_set_device(i));
-// create sycl streams
-for (int is = 0; is < MAX_STREAMS; ++is) {
-SYCL_CHECK(CHECK_TRY_ERROR(
-g_syclStreams[i][is] =
-dpct::get_current_device().create_queue(
-g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device())));
-}
-const dpct::queue_ptr stream = g_syclStreams[i][0];
-// create sycl handle
-SYCL_CHECK(CHECK_TRY_ERROR(g_sycl_handles[i] = stream));
-}
+if (CHECK_TRY_ERROR(g_all_sycl_device_count =
+dpct::dev_mgr::instance().device_count()) != 0) {
+initialized = true;
+g_sycl_loaded = false;
+return;
+}
+GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
+ggml_backend_sycl_print_sycl_devices();

initialized = true;
g_sycl_loaded = true;
}
@@ -13143,6 +13218,63 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
void ggml_init_by_gpus(int device_count) try {
g_device_count = device_count;
g_work_group_size = g_sycl_gpu_mgr->work_group_size;
int64_t total_vram = 0;
print_gpu_device_list();
for (int id = 0; id < GGML_SYCL_MAX_DEVICES; ++id) {
g_device_caps[id].vmm = 0;
g_device_caps[id].device_id = -1;
g_device_caps[id].cc = 0;
g_tensor_split[id] = 0;
g_default_tensor_split[id] = 0;
}
for (int i = 0; i < g_device_count; ++i) {
int device_id = g_sycl_gpu_mgr->gpus[i];
g_device_caps[i].vmm = 0;
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, dpct::dev_mgr::instance().get_device(device_id))));
g_default_tensor_split[i] = total_vram;
total_vram += prop.get_global_mem_size();
g_device_caps[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
}
for (int i = 0; i < g_device_count; ++i) {
g_default_tensor_split[i] /= total_vram;
}
for (int i = 0; i < g_device_count; ++i) {
SYCL_CHECK(ggml_sycl_set_device(i));
// create sycl streams
for (int is = 0; is < MAX_STREAMS; ++is) {
SYCL_CHECK(CHECK_TRY_ERROR(
g_syclStreams[i][is] =
dpct::get_current_device().create_queue(
g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device())));
}
const dpct::queue_ptr stream = g_syclStreams[i][0];
// create sycl handle
SYCL_CHECK(CHECK_TRY_ERROR(g_sycl_handles[i] = stream));
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
void *ggml_sycl_host_malloc(size_t size) try {
if (getenv("GGML_SYCL_NO_PINNED") != nullptr) {
return nullptr;
@@ -16542,22 +16674,24 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
/* .is_host = */ nullptr,
};
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
+ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_index) {
+if (device_index>=g_device_count or device_index<0) {
+printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+device_index, g_device_count-1);
+GGML_ASSERT(device_index<g_device_count);
+}
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-static bool ggml_backend_sycl_buffer_type_initialized = false;
-if (!ggml_backend_sycl_buffer_type_initialized) {
+if (!g_ggml_backend_sycl_buffer_type_initialized) {
for (int i = 0; i < g_device_count; i++) {
ggml_backend_sycl_buffer_types[i] = {
/* .iface = */ ggml_backend_sycl_buffer_type_interface,
/* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(g_sycl_gpu_mgr->gpus[i])},
};
}
-ggml_backend_sycl_buffer_type_initialized = true;
+g_ggml_backend_sycl_buffer_type_initialized = true;
}
-return &ggml_backend_sycl_buffer_types[device];
+return &ggml_backend_sycl_buffer_types[device_index];
}
// sycl split buffer type
@@ -17310,11 +17444,42 @@ GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id) {
return g_sycl_gpu_mgr->get_index(device_id);
}
GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index) {
return g_sycl_gpu_mgr->gpus[device_index];
}
GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id) {
GGML_ASSERT(main_gpu_id<g_all_sycl_device_count);
fprintf(stderr, "ggml_backend_sycl_set_single_device: use single device: [%d]\n", main_gpu_id);
if (g_sycl_gpu_mgr) {
delete g_sycl_gpu_mgr;
}
g_sycl_gpu_mgr = new sycl_gpu_mgr(main_gpu_id);
g_ggml_sycl_backend_gpu_mode = SYCL_SINGLE_GPU_MODE;
ggml_init_by_gpus(g_sycl_gpu_mgr->get_gpu_count());
g_ggml_backend_sycl_buffer_type_initialized = false;
}
GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode() {
if (g_ggml_sycl_backend_gpu_mode == SYCL_MUL_GPU_MODE) {
return;
}
fprintf(stderr, "ggml_backend_sycl_set_mul_device_mode: true\n");
if (g_sycl_gpu_mgr) {
delete g_sycl_gpu_mgr;
}
g_sycl_gpu_mgr = new sycl_gpu_mgr();
g_ggml_sycl_backend_gpu_mode = SYCL_MUL_GPU_MODE;
ggml_init_by_gpus(g_sycl_gpu_mgr->get_gpu_count());
g_ggml_backend_sycl_buffer_type_initialized = false;
}
extern "C" int ggml_backend_sycl_reg_devices(); extern "C" int ggml_backend_sycl_reg_devices();
int ggml_backend_sycl_reg_devices() { int ggml_backend_sycl_reg_devices() {
if (!g_sycl_gpu_mgr) g_sycl_gpu_mgr = new sycl_gpu_mgr(); ggml_backend_sycl_set_mul_device_mode();
g_device_count = g_sycl_gpu_mgr->get_gpu_count();
assert(g_device_count>0); assert(g_device_count>0);
for (int i = 0; i < g_device_count; i++) { for (int i = 0; i < g_device_count; i++) {
int id = g_sycl_gpu_mgr->gpus[i]; int id = g_sycl_gpu_mgr->gpus[i];

View file

@@ -29,6 +29,11 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id);
// TODO: these are temporary
// ref: https://github.com/ggerganov/llama.cpp/pull/6022#issuecomment-1992615670
GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index);
GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id);
GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode();
#ifdef __cplusplus
}
#endif

ggml.c (17 changes)
View file

@@ -470,6 +470,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(int32_t),
.is_quantized = false,
},
[GGML_TYPE_I64] = {
.type_name = "i64",
.blck_size = 1,
.type_size = sizeof(int64_t),
.is_quantized = false,
},
[GGML_TYPE_F64] = {
.type_name = "f64",
.blck_size = 1,
.type_size = sizeof(double),
.is_quantized = false,
.nrows = 1,
},
[GGML_TYPE_F32] = {
.type_name = "f32",
.blck_size = 1,
@@ -12418,6 +12431,8 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);
@@ -12504,6 +12519,8 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);

ggml.h (2 changes)
View file

@@ -366,6 +366,8 @@ extern "C" {
GGML_TYPE_I8 = 24,
GGML_TYPE_I16 = 25,
GGML_TYPE_I32 = 26,
GGML_TYPE_I64 = 27,
GGML_TYPE_F64 = 28,
GGML_TYPE_COUNT,
};
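The two new types are storage types: tensors can be created and serialized with them, but the compute paths assert on them, as the ggml.c hunks above show for alibi and clamp. A minimal sketch of allocating one through the standard ggml API:

// Sketch: create an I64 tensor; ggml will store it, but float ops reject it.
#include "ggml.h"

struct ggml_init_params params = {
    /* .mem_size   = */ 1024 * 1024,
    /* .mem_buffer = */ NULL,
    /* .no_alloc   = */ false,
};
struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor  * t   = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 8);
int64_t * data = (int64_t *) t->data;
for (int i = 0; i < 8; i++) {
    data[i] = i; // fill with plain integers
}
ggml_free(ctx);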

View file

@@ -42,6 +42,7 @@ class Keys:
        EXPERT_COUNT = "{arch}.expert_count"
        EXPERT_USED_COUNT = "{arch}.expert_used_count"
        POOLING_TYPE = "{arch}.pooling_type"
+        LOGIT_SCALE = "{arch}.logit_scale"
    class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"
@@ -121,6 +122,7 @@ class MODEL_ARCH(IntEnum):
    GEMMA = auto()
    STARCODER2 = auto()
    MAMBA = auto()
+    COMMAND_R = auto()
class MODEL_TENSOR(IntEnum):
@@ -187,6 +189,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.GEMMA: "gemma",
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.MAMBA: "mamba",
+    MODEL_ARCH.COMMAND_R: "command-r",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -579,6 +582,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.SSM_D,
        MODEL_TENSOR.SSM_OUT,
    ],
+    MODEL_ARCH.COMMAND_R: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
    # TODO
}
@@ -665,6 +680,8 @@ class GGMLQuantizationType(IntEnum):
    I8 = 24
    I16 = 25
    I32 = 26
+    I64 = 27
+    F64 = 28
class GGUFEndian(IntEnum):
@@ -734,6 +751,8 @@ GGML_QUANT_SIZES = {
    GGMLQuantizationType.I8: (1, 1),
    GGMLQuantizationType.I16: (1, 2),
    GGMLQuantizationType.I32: (1, 4),
+    GGMLQuantizationType.I64: (1, 8),
+    GGMLQuantizationType.F64: (1, 8),
}

View file

@@ -242,12 +242,15 @@ class GGUFReader:
        n_bytes = n_elems * type_size // block_size
        data_offs = int(start_offs + offset_tensor[0])
        item_type: npt.DTypeLike
-        if ggml_type == GGMLQuantizationType.F32:
-            item_count = n_elems
-            item_type = np.float32
-        elif ggml_type == GGMLQuantizationType.F16:
+        if ggml_type == GGMLQuantizationType.F16:
            item_count = n_elems
            item_type = np.float16
+        elif ggml_type == GGMLQuantizationType.F32:
+            item_count = n_elems
+            item_type = np.float32
+        elif ggml_type == GGMLQuantizationType.F64:
+            item_count = n_elems
+            item_type = np.float64
        elif ggml_type == GGMLQuantizationType.I8:
            item_count = n_elems
            item_type = np.int8

@@ -257,6 +260,9 @@ class GGUFReader:
        elif ggml_type == GGMLQuantizationType.I32:
            item_count = n_elems
            item_type = np.int32
+        elif ggml_type == GGMLQuantizationType.I64:
+            item_count = n_elems
+            item_type = np.int64
        else:
            item_count = n_bytes
            item_type = np.uint8

View file

@@ -204,18 +204,22 @@ class GGUFWriter:
        for i in range(n_dims):
            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
        if raw_dtype is None:
-            if tensor_dtype == np.float32:
-                dtype = GGMLQuantizationType.F32
-            elif tensor_dtype == np.float16:
+            if tensor_dtype == np.float16:
                dtype = GGMLQuantizationType.F16
+            elif tensor_dtype == np.float32:
+                dtype = GGMLQuantizationType.F32
+            elif tensor_dtype == np.float64:
+                dtype = GGMLQuantizationType.F64
            elif tensor_dtype == np.int8:
                dtype = GGMLQuantizationType.I8
            elif tensor_dtype == np.int16:
                dtype = GGMLQuantizationType.I16
            elif tensor_dtype == np.int32:
                dtype = GGMLQuantizationType.I32
+            elif tensor_dtype == np.int64:
+                dtype = GGMLQuantizationType.I64
            else:
-                raise ValueError("Only F32, F16, I8, I16, I32 tensors are supported for now")
+                raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
        else:
            dtype = raw_dtype
        self.ti_data += self._pack("I", dtype)
@@ -357,6 +361,9 @@ class GGUFWriter:
    def add_clamp_kqv(self, value: float) -> None:
        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
+    def add_logit_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
    def add_expert_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

llama.cpp (357 changes)
View file

@@ -214,6 +214,7 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
+LLM_ARCH_COMMAND_R,
LLM_ARCH_UNKNOWN,
};
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
+{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -268,6 +270,7 @@ enum llm_kv {
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_POOLING_TYPE,
+LLM_KV_LOGIT_SCALE,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -332,6 +335,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
+{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -838,6 +842,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_COMMAND_R,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1597,6 +1616,7 @@ enum e_model {
MODEL_20B,
MODEL_30B,
MODEL_34B,
+MODEL_35B,
MODEL_40B,
MODEL_65B,
MODEL_70B,
@@ -1643,6 +1663,7 @@ struct llama_hparams {
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
+float f_logit_scale = 0.0f;

bool causal_attn = true;
bool need_kq_pos = false;
@@ -1873,6 +1894,31 @@ struct llama_kv_cache {
}
};
struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
int32_t layer_start = -1;
int32_t layer_end = -1;
ggml_tensor * tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr;
}
return tensors[il];
}
~llama_control_vector() {
for (struct ggml_context * ctx : ctxs) {
ggml_free(ctx);
}
for (ggml_backend_buffer_t buf : bufs) {
ggml_backend_buffer_free(buf);
}
}
};
struct llama_vocab {
using id = int32_t;
using token = std::string;
@@ -2087,6 +2133,9 @@ struct llama_context {
struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
// control vectors
struct llama_control_vector cvec;
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
#endif
@ -3231,6 +3280,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_20B: return "20B";
case MODEL_30B: return "30B";
case MODEL_34B: return "34B";
case MODEL_35B: return "35B";
case MODEL_40B: return "40B";
case MODEL_65B: return "65B";
case MODEL_70B: return "70B";
@ -3623,6 +3673,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_COMMAND_R:
{
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_35B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}
@ -3944,6 +4003,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@ -4918,6 +4978,37 @@ static bool llm_load_tensors(
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
}
} break;
case LLM_ARCH_COMMAND_R:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
// init output from the input tok embed
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
ml.n_created--; // artificial tensor
ml.size_data += ggml_nbytes(model.output);
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -5064,6 +5155,16 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
}
#endif
#ifdef GGML_USE_SYCL
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
ggml_backend_sycl_set_single_device_mode(params.main_gpu);
// SYCL uses device indices (0, 1, 2, ...) directly; the user passes a device id, so convert it to an index
params.main_gpu = ggml_backend_sycl_get_device_index(params.main_gpu);
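// e.g. (hypothetical device set): if the visible SYCL device ids are {0, 2, 3},
// a user-supplied main_gpu of 2 becomes device index 1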
} else {
ggml_backend_sycl_set_mul_device_mode();
}
#endif
if (!llm_load_tensors(
ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
@ -5858,6 +5959,12 @@ struct llm_build_context {
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
if (layer_dir != nullptr) {
cur = ggml_add(ctx0, cur, layer_dir);
}
cb(cur, "l_out", il); cb(cur, "l_out", il);
// input for next layer // input for next layer
@ -5893,7 +6000,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@ -5943,7 +6050,6 @@ struct llm_build_context {
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
@ -8305,6 +8411,121 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_command_r() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const float f_logit_scale = hparams.f_logit_scale;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
struct ggml_tensor * ffn_inp = cur;
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
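// keep the attention branch; it is summed with the FFN branch and the residual below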
struct ggml_tensor * attn_out = cur;
// feed-forward network
{
cur = llm_build_ffn(ctx0, ffn_inp,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
// add together residual + FFN + self-attention
cur = ggml_add(ctx0, cur, inpL);
cur = ggml_add(ctx0, cur, attn_out);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);
}
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
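Unlike the sequential residual layout of the other graphs, build_command_r computes attention and the feed-forward network in parallel from a single normalized input, sums both into the residual stream, and scales the tied-embedding logits. As a summary of the code above (notation is editorial, not from the diff):
h_{l+1} = h_l + \mathrm{Attn}(\mathrm{LN}(h_l)) + \mathrm{FFN}(\mathrm{LN}(h_l)), \qquad \mathrm{logits} = f_{\mathrm{logit\_scale}} \cdot E^{\top}\,\mathrm{LN}(h_L)
where E is the token embedding matrix (model.output is created from LLM_TENSOR_TOKEN_EMBD) and LN is the non-RMS layer norm selected by LLM_NORM.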
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@ -8487,6 +8708,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_mamba();
} break;
case LLM_ARCH_COMMAND_R:
{
result = llm.build_command_r();
} break;
default:
GGML_ASSERT(false);
}
@ -12921,23 +13146,22 @@ struct llama_context * llama_new_context_with_model(
if (model->n_gpu_layers > 0) {
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
if (backend == nullptr) {
int main_gpu_id = ggml_backend_sycl_get_device_id(model->main_gpu);
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, main_gpu_id, model->main_gpu);
llama_free(ctx);
return nullptr;
}
ctx->backends.push_back(backend);
} else {
// LLAMA_SPLIT_LAYER requires a backend for each GPU
for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
ggml_backend_t backend = ggml_backend_sycl_init(i);
if (backend == nullptr) {
int id_list[GGML_SYCL_MAX_DEVICES];
ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, id_list[i], i);
llama_free(ctx);
return nullptr;
}
@ -13138,6 +13362,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_ORION:
case LLM_ARCH_INTERNLM2:
case LLM_ARCH_MINICPM:
case LLM_ARCH_COMMAND_R:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@ -13174,6 +13399,10 @@ int32_t llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
int32_t llama_n_layer(const struct llama_model * model) {
return model->hparams.n_layer;
}
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}
@ -13273,6 +13502,96 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
}
}
static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
GGML_ASSERT(cvec.tensors.empty());
GGML_ASSERT(cvec.ctxs.empty());
GGML_ASSERT(cvec.bufs.empty());
// count layer buffer types
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
for (int64_t i = 0; i < model.hparams.n_layer; i++) {
buft_layer_count[model.buft_layer[i].buft]++;
}
// allocate contexts
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
for (auto & it : buft_layer_count) {
int n_layers = it.second;
struct ggml_init_params params = {
/*.mem_size =*/ n_layers * ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
return false;
}
ctx_map[it.first] = ctx;
}
// make tensors
cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
for (size_t il = 1; il < model.hparams.n_layer; il++) {
struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
cvec.tensors.push_back(tensor);
}
// allocate tensors / buffers and zero
for (auto it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
return false;
}
ggml_backend_buffer_clear(buf, 0);
cvec.ctxs.push_back(ctx);
cvec.bufs.push_back(buf);
}
return true;
}
int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
const llama_model & model = lctx->model;
llama_control_vector & cvec = lctx->cvec;
if (data == nullptr) {
// disable the current control vector (but leave allocated for later)
cvec.layer_start = -1;
cvec.layer_end = -1;
return 0;
}
if (n_embd != (int) model.hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return 1;
}
if (cvec.tensors.empty()) {
if (!llama_control_vector_init(cvec, model)) {
return 1;
}
}
cvec.layer_start = il_start;
cvec.layer_end = il_end;
for (size_t il = 1; il < model.hparams.n_layer; il++) {
assert(cvec.tensors[il] != nullptr);
const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
if (off + n_embd <= len) {
ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
}
}
return 0;
}
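A worked example of the buffer layout implied by the offsets above (sizes hypothetical): data holds one n_embd row per layer starting at layer 1, so the row for layer il begins at data + (il - 1) * n_embd; with n_embd = 4096 and n_layer = 40 a full buffer is 39 * 4096 floats, and the layer-5 row starts at data + 4 * 4096. Layers whose row would run past len are left untouched by the bounds check above.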
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
@ -14242,6 +14561,26 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
// OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message support, we will merge it with user prompt
system_prompt = message->content;
continue;
} else if (role == "user") {
ss << "Human: ";
if (!system_prompt.empty()) {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << message->content << "\n\nAssistant: </s>";
} else {
ss << message->content << "</s>";
}
}
} else {
// template not supported
return -1;

llama.h
View file

@ -388,6 +388,7 @@ extern "C" {
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@ -435,10 +436,24 @@ extern "C" {
// Returns 0 on success
LLAMA_API int32_t llama_model_apply_lora_from_file(
const struct llama_model * model,
const char * path_lora,
float scale,
const char * path_base_model,
int32_t n_threads);
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
// the currently loaded vector.
// n_embd should be the size of a single layer's control vector, and data should
// point to an n_embd x (n_layer - 1) buffer starting from layer 1.
// il_start and il_end are the layer range the vector should apply to (both inclusive)
// See llama_control_vector_load in common to load a control vector.
LLAMA_API int32_t llama_control_vector_apply(
struct llama_context * lctx,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
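A hedged usage sketch from the client side (model and ctx are assumed to exist already; everything else is illustrative, with error handling elided):
const int32_t n_embd = llama_n_embd(model);
const int32_t n_layer = llama_n_layer(model);
std::vector<float> buf((size_t) n_embd * (n_layer - 1), 0.0f); // rows for layers 1..n_layer-1
// ... fill buf with one n_embd-sized direction per layer ...
if (llama_control_vector_apply(ctx, buf.data(), buf.size(), n_embd, 1, n_layer - 1) != 0) {
// handle the error
}
// passing data == NULL later disables the vector without freeing its buffers:
// llama_control_vector_apply(ctx, NULL, 0, 0, 0, 0);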
//
// KV cache

View file

@ -31,6 +31,8 @@ int main(void) {
"{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
// google/gemma-7b-it // google/gemma-7b-it
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}", "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
// OrionStarAI/Orion-14B-Chat
"{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
};
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B
@ -45,6 +47,8 @@ int main(void) {
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n", "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
// google/gemma-7b-it // google/gemma-7b-it
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
// OrionStarAI/Orion-14B-Chat
"Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
};
std::vector<char> formatted_chat(1024);
int32_t res;