Merge branch 'master' into gg/metal-feature-set
ggml-ci
This commit is contained in:
commit
f81e467a47
30 changed files with 4041 additions and 2336 deletions
|
@ -543,9 +543,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
|
||||||
params.n_gpu_layers = std::stoi(argv[i]);
|
params.n_gpu_layers = std::stoi(argv[i]);
|
||||||
#else
|
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||||
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
|
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
|
||||||
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
||||||
#endif
|
#endif
|
||||||
|
@ -554,9 +553,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
|
||||||
params.n_gpu_layers_draft = std::stoi(argv[i]);
|
params.n_gpu_layers_draft = std::stoi(argv[i]);
|
||||||
#else
|
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||||
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
|
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
|
||||||
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
||||||
#endif
|
#endif
|
||||||
|
@ -565,25 +563,44 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
params.main_gpu = std::stoi(argv[i]);
|
params.main_gpu = std::stoi(argv[i]);
|
||||||
#else
|
#ifndef GGML_USE_CUBLAS
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the main GPU has no effect.\n");
|
||||||
#endif
|
#endif // GGML_USE_CUBLAS
|
||||||
|
} else if (arg == "--split-mode" || arg == "-sm") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::string arg_next = argv[i];
|
||||||
|
if (arg_next == "none") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_NONE;
|
||||||
|
} else if (arg_next == "layer") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_LAYER;
|
||||||
|
} else if (arg_next == "row") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_ROW;
|
||||||
|
} else {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#ifndef GGML_USE_CUBLAS
|
||||||
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
} else if (arg == "--tensor-split" || arg == "-ts") {
|
} else if (arg == "--tensor-split" || arg == "-ts") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
std::string arg_next = argv[i];
|
std::string arg_next = argv[i];
|
||||||
|
|
||||||
// split string by , and /
|
// split string by , and /
|
||||||
const std::regex regex{R"([,/]+)"};
|
const std::regex regex{R"([,/]+)"};
|
||||||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||||
std::vector<std::string> split_arg{it, {}};
|
std::vector<std::string> split_arg{it, {}};
|
||||||
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
|
if (split_arg.size() >= LLAMA_MAX_DEVICES) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
|
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
|
||||||
if (i < split_arg.size()) {
|
if (i < split_arg.size()) {
|
||||||
params.tensor_split[i] = std::stof(split_arg[i]);
|
params.tensor_split[i] = std::stof(split_arg[i]);
|
||||||
|
@ -591,14 +608,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
params.tensor_split[i] = 0.0f;
|
params.tensor_split[i] = 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#ifndef GGML_USE_CUBLAS
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting a tensor split has no effect.\n");
|
||||||
#endif // GGML_USE_CUBLAS
|
|
||||||
} else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
params.mul_mat_q = false;
|
|
||||||
#else
|
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
|
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
} else if (arg == "--no-mmap") {
|
} else if (arg == "--no-mmap") {
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
|
@ -915,14 +926,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
printf(" number of layers to store in VRAM\n");
|
printf(" number of layers to store in VRAM\n");
|
||||||
printf(" -ngld N, --n-gpu-layers-draft N\n");
|
printf(" -ngld N, --n-gpu-layers-draft N\n");
|
||||||
printf(" number of layers to store in VRAM for the draft model\n");
|
printf(" number of layers to store in VRAM for the draft model\n");
|
||||||
|
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
|
||||||
|
printf(" how to split the model across multiple GPUs, one of:\n");
|
||||||
|
printf(" - none: use one GPU only\n");
|
||||||
|
printf(" - layer (default): split layers and KV across GPUs\n");
|
||||||
|
printf(" - row: split rows across GPUs\n");
|
||||||
printf(" -ts SPLIT, --tensor-split SPLIT\n");
|
printf(" -ts SPLIT, --tensor-split SPLIT\n");
|
||||||
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
|
||||||
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
|
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
|
||||||
#ifdef GGML_USE_CUBLAS
|
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
|
||||||
printf(" -nommq, --no-mul-mat-q\n");
|
|
||||||
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
|
|
||||||
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
|
||||||
#endif // GGML_USE_CUBLAS
|
|
||||||
#endif
|
#endif
|
||||||
printf(" -gan N, --grp-attn-n N\n");
|
printf(" -gan N, --grp-attn-n N\n");
|
||||||
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
|
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
|
||||||
|
@ -1041,6 +1053,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||||
}
|
}
|
||||||
mparams.main_gpu = params.main_gpu;
|
mparams.main_gpu = params.main_gpu;
|
||||||
|
mparams.split_mode = params.split_mode;
|
||||||
mparams.tensor_split = params.tensor_split;
|
mparams.tensor_split = params.tensor_split;
|
||||||
mparams.use_mmap = params.use_mmap;
|
mparams.use_mmap = params.use_mmap;
|
||||||
mparams.use_mlock = params.use_mlock;
|
mparams.use_mlock = params.use_mlock;
|
||||||
|
|
|
@ -59,6 +59,7 @@ struct gpt_params {
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
|
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||||
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
||||||
|
|
|
@ -23,6 +23,15 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
import gguf
|
import gguf
|
||||||
|
|
||||||
|
|
||||||
|
# check for any of the given keys in the dictionary and return the value of the first key found
|
||||||
|
def get_key_opts(d, keys):
|
||||||
|
for k in keys:
|
||||||
|
if k in d:
|
||||||
|
return d[k]
|
||||||
|
print(f"Could not find any of {keys}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
###### MODEL DEFINITIONS ######
|
###### MODEL DEFINITIONS ######
|
||||||
|
|
||||||
class SentencePieceTokenTypes(IntEnum):
|
class SentencePieceTokenTypes(IntEnum):
|
||||||
|
@ -257,10 +266,11 @@ class Model:
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||||
elif reverse_vocab[i] in added_vocab:
|
elif reverse_vocab[i] in added_vocab:
|
||||||
tokens.append(reverse_vocab[i])
|
tokens.append(reverse_vocab[i])
|
||||||
if tokenizer.added_tokens_decoder[i].special:
|
if hasattr(tokenizer, "added_tokens_decoder"):
|
||||||
toktypes.append(gguf.TokenType.CONTROL)
|
if tokenizer.added_tokens_decoder[i].special:
|
||||||
else:
|
toktypes.append(gguf.TokenType.CONTROL)
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
else:
|
||||||
|
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||||
else:
|
else:
|
||||||
tokens.append(reverse_vocab[i])
|
tokens.append(reverse_vocab[i])
|
||||||
toktypes.append(gguf.TokenType.NORMAL)
|
toktypes.append(gguf.TokenType.NORMAL)
|
||||||
|
@ -1068,17 +1078,22 @@ class GPT2Model(Model):
|
||||||
|
|
||||||
class Phi2Model(Model):
|
class Phi2Model(Model):
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["n_layer"]
|
block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
|
||||||
|
|
||||||
|
rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
|
||||||
|
n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
|
||||||
|
n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
|
||||||
|
|
||||||
self.gguf_writer.add_name("Phi2")
|
self.gguf_writer.add_name("Phi2")
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(n_embd)
|
||||||
|
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(block_count)
|
||||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
|
self.gguf_writer.add_head_count_kv(n_head)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
|
||||||
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
|
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
self.gguf_writer.add_add_bos_token(False)
|
self.gguf_writer.add_add_bos_token(False)
|
||||||
|
|
||||||
|
|
|
@ -88,7 +88,10 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_model_params model_params = llama_model_default_params();
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
|
||||||
|
const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);
|
||||||
|
|
||||||
model_params.n_gpu_layers = n_gpu_layers;
|
model_params.n_gpu_layers = n_gpu_layers;
|
||||||
|
model_params.tensor_split = t_split.data();
|
||||||
|
|
||||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
||||||
|
|
||||||
|
|
|
@ -245,9 +245,8 @@ static struct lora_data * load_lora(struct lora_info * info) {
|
||||||
params_ggml.no_alloc = true;
|
params_ggml.no_alloc = true;
|
||||||
result->ctx = ggml_init(params_ggml);
|
result->ctx = ggml_init(params_ggml);
|
||||||
|
|
||||||
uint32_t LLAMA_FILE_MAGIC_LORA = 0x67676C61; // 'ggla'
|
|
||||||
uint32_t magic = file.read_u32();
|
uint32_t magic = file.read_u32();
|
||||||
if (magic != LLAMA_FILE_MAGIC_LORA) {
|
if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||||
die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
|
die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
|
||||||
}
|
}
|
||||||
uint32_t version = file.read_u32();
|
uint32_t version = file.read_u32();
|
||||||
|
|
|
@ -128,6 +128,25 @@ static std::string get_gpu_info() {
|
||||||
// command line params
|
// command line params
|
||||||
enum output_formats {CSV, JSON, MARKDOWN, SQL};
|
enum output_formats {CSV, JSON, MARKDOWN, SQL};
|
||||||
|
|
||||||
|
static const char * output_format_str(output_formats format) {
|
||||||
|
switch (format) {
|
||||||
|
case CSV: return "csv";
|
||||||
|
case JSON: return "json";
|
||||||
|
case MARKDOWN: return "md";
|
||||||
|
case SQL: return "sql";
|
||||||
|
default: GGML_ASSERT(!"invalid output format");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * split_mode_str(llama_split_mode mode) {
|
||||||
|
switch (mode) {
|
||||||
|
case LLAMA_SPLIT_NONE: return "none";
|
||||||
|
case LLAMA_SPLIT_LAYER: return "layer";
|
||||||
|
case LLAMA_SPLIT_ROW: return "row";
|
||||||
|
default: GGML_ASSERT(!"invalid split mode");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct cmd_params {
|
struct cmd_params {
|
||||||
std::vector<std::string> model;
|
std::vector<std::string> model;
|
||||||
std::vector<int> n_prompt;
|
std::vector<int> n_prompt;
|
||||||
|
@ -137,6 +156,7 @@ struct cmd_params {
|
||||||
std::vector<ggml_type> type_v;
|
std::vector<ggml_type> type_v;
|
||||||
std::vector<int> n_threads;
|
std::vector<int> n_threads;
|
||||||
std::vector<int> n_gpu_layers;
|
std::vector<int> n_gpu_layers;
|
||||||
|
std::vector<llama_split_mode> split_mode;
|
||||||
std::vector<int> main_gpu;
|
std::vector<int> main_gpu;
|
||||||
std::vector<bool> no_kv_offload;
|
std::vector<bool> no_kv_offload;
|
||||||
std::vector<bool> mul_mat_q;
|
std::vector<bool> mul_mat_q;
|
||||||
|
@ -155,6 +175,7 @@ static const cmd_params cmd_params_defaults = {
|
||||||
/* type_v */ {GGML_TYPE_F16},
|
/* type_v */ {GGML_TYPE_F16},
|
||||||
/* n_threads */ {get_num_physical_cores()},
|
/* n_threads */ {get_num_physical_cores()},
|
||||||
/* n_gpu_layers */ {99},
|
/* n_gpu_layers */ {99},
|
||||||
|
/* split_mode */ {LLAMA_SPLIT_LAYER},
|
||||||
/* main_gpu */ {0},
|
/* main_gpu */ {0},
|
||||||
/* no_kv_offload */ {false},
|
/* no_kv_offload */ {false},
|
||||||
/* mul_mat_q */ {true},
|
/* mul_mat_q */ {true},
|
||||||
|
@ -169,21 +190,22 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("options:\n");
|
printf("options:\n");
|
||||||
printf(" -h, --help\n");
|
printf(" -h, --help\n");
|
||||||
printf(" -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
|
printf(" -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
|
||||||
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
|
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
|
||||||
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
|
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
|
||||||
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
|
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
|
||||||
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
|
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
|
||||||
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
|
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
|
||||||
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
|
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
|
||||||
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
|
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
|
||||||
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
|
||||||
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
||||||
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
|
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||||
printf(" -ts, --tensor_split <ts0/ts1/..> \n");
|
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
|
||||||
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
|
printf(" -ts, --tensor_split <ts0/ts1/..> (default: 0)\n");
|
||||||
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql");
|
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
|
||||||
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
|
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
|
||||||
|
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
|
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
|
||||||
}
|
}
|
||||||
|
@ -306,6 +328,28 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
auto p = split<int>(argv[i], split_delim);
|
auto p = split<int>(argv[i], split_delim);
|
||||||
params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
|
params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
|
||||||
|
} else if (arg == "-sm" || arg == "--split-mode") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto p = split<std::string>(argv[i], split_delim);
|
||||||
|
std::vector<llama_split_mode> modes;
|
||||||
|
for (const auto & m : p) {
|
||||||
|
llama_split_mode mode;
|
||||||
|
if (m == "none") {
|
||||||
|
mode = LLAMA_SPLIT_NONE;
|
||||||
|
} else if (m == "layer") {
|
||||||
|
mode = LLAMA_SPLIT_LAYER;
|
||||||
|
} else if (m == "row") {
|
||||||
|
mode = LLAMA_SPLIT_ROW;
|
||||||
|
} else {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
modes.push_back(mode);
|
||||||
|
}
|
||||||
|
params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end());
|
||||||
} else if (arg == "-mg" || arg == "--main-gpu") {
|
} else if (arg == "-mg" || arg == "--main-gpu") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -392,6 +436,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||||
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
|
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
|
||||||
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
|
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
|
||||||
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
|
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
|
||||||
|
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
|
||||||
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
|
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
|
||||||
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
|
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
|
||||||
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
|
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
|
||||||
|
@ -410,6 +455,7 @@ struct cmd_params_instance {
|
||||||
ggml_type type_v;
|
ggml_type type_v;
|
||||||
int n_threads;
|
int n_threads;
|
||||||
int n_gpu_layers;
|
int n_gpu_layers;
|
||||||
|
llama_split_mode split_mode;
|
||||||
int main_gpu;
|
int main_gpu;
|
||||||
bool no_kv_offload;
|
bool no_kv_offload;
|
||||||
bool mul_mat_q;
|
bool mul_mat_q;
|
||||||
|
@ -419,6 +465,7 @@ struct cmd_params_instance {
|
||||||
llama_model_params mparams = llama_model_default_params();
|
llama_model_params mparams = llama_model_default_params();
|
||||||
|
|
||||||
mparams.n_gpu_layers = n_gpu_layers;
|
mparams.n_gpu_layers = n_gpu_layers;
|
||||||
|
mparams.split_mode = split_mode;
|
||||||
mparams.main_gpu = main_gpu;
|
mparams.main_gpu = main_gpu;
|
||||||
mparams.tensor_split = tensor_split.data();
|
mparams.tensor_split = tensor_split.data();
|
||||||
|
|
||||||
|
@ -428,6 +475,7 @@ struct cmd_params_instance {
|
||||||
bool equal_mparams(const cmd_params_instance & other) const {
|
bool equal_mparams(const cmd_params_instance & other) const {
|
||||||
return model == other.model &&
|
return model == other.model &&
|
||||||
n_gpu_layers == other.n_gpu_layers &&
|
n_gpu_layers == other.n_gpu_layers &&
|
||||||
|
split_mode == other.split_mode &&
|
||||||
main_gpu == other.main_gpu &&
|
main_gpu == other.main_gpu &&
|
||||||
tensor_split == other.tensor_split;
|
tensor_split == other.tensor_split;
|
||||||
}
|
}
|
||||||
|
@ -446,45 +494,13 @@ struct cmd_params_instance {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_params & params, int n_gen, int n_prompt) {
|
|
||||||
std::vector<cmd_params_instance> instances;
|
|
||||||
|
|
||||||
for (const auto & m : params.model)
|
|
||||||
for (const auto & nl : params.n_gpu_layers)
|
|
||||||
for (const auto & mg : params.main_gpu)
|
|
||||||
for (const auto & ts : params.tensor_split)
|
|
||||||
for (const auto & nb : params.n_batch)
|
|
||||||
for (const auto & tk : params.type_k)
|
|
||||||
for (const auto & tv : params.type_v)
|
|
||||||
for (const auto & mmq : params.mul_mat_q)
|
|
||||||
for (const auto & nkvo : params.no_kv_offload)
|
|
||||||
for (const auto & nt : params.n_threads) {
|
|
||||||
cmd_params_instance instance = {
|
|
||||||
/* .model = */ m,
|
|
||||||
/* .n_prompt = */ n_prompt,
|
|
||||||
/* .n_gen = */ n_gen,
|
|
||||||
/* .n_batch = */ nb,
|
|
||||||
/* .type_k = */ tk,
|
|
||||||
/* .type_v = */ tv,
|
|
||||||
/* .n_threads = */ nt,
|
|
||||||
/* .n_gpu_layers = */ nl,
|
|
||||||
/* .main_gpu = */ mg,
|
|
||||||
/* .no_kv_offload= */ nkvo,
|
|
||||||
/* .mul_mat_q = */ mmq,
|
|
||||||
/* .tensor_split = */ ts,
|
|
||||||
};
|
|
||||||
instances.push_back(instance);
|
|
||||||
}
|
|
||||||
return instances;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
|
static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
|
||||||
std::vector<cmd_params_instance> instances;
|
std::vector<cmd_params_instance> instances;
|
||||||
|
|
||||||
#if 1
|
|
||||||
// this ordering minimizes the number of times that each model needs to be reloaded
|
// this ordering minimizes the number of times that each model needs to be reloaded
|
||||||
for (const auto & m : params.model)
|
for (const auto & m : params.model)
|
||||||
for (const auto & nl : params.n_gpu_layers)
|
for (const auto & nl : params.n_gpu_layers)
|
||||||
|
for (const auto & sm : params.split_mode)
|
||||||
for (const auto & mg : params.main_gpu)
|
for (const auto & mg : params.main_gpu)
|
||||||
for (const auto & ts : params.tensor_split)
|
for (const auto & ts : params.tensor_split)
|
||||||
for (const auto & nb : params.n_batch)
|
for (const auto & nb : params.n_batch)
|
||||||
|
@ -506,6 +522,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
/* .type_v = */ tv,
|
/* .type_v = */ tv,
|
||||||
/* .n_threads = */ nt,
|
/* .n_threads = */ nt,
|
||||||
/* .n_gpu_layers = */ nl,
|
/* .n_gpu_layers = */ nl,
|
||||||
|
/* .split_mode = */ sm,
|
||||||
/* .main_gpu = */ mg,
|
/* .main_gpu = */ mg,
|
||||||
/* .no_kv_offload= */ nkvo,
|
/* .no_kv_offload= */ nkvo,
|
||||||
/* .mul_mat_q = */ mmq,
|
/* .mul_mat_q = */ mmq,
|
||||||
|
@ -527,6 +544,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
/* .type_v = */ tv,
|
/* .type_v = */ tv,
|
||||||
/* .n_threads = */ nt,
|
/* .n_threads = */ nt,
|
||||||
/* .n_gpu_layers = */ nl,
|
/* .n_gpu_layers = */ nl,
|
||||||
|
/* .split_mode = */ sm,
|
||||||
/* .main_gpu = */ mg,
|
/* .main_gpu = */ mg,
|
||||||
/* .no_kv_offload= */ nkvo,
|
/* .no_kv_offload= */ nkvo,
|
||||||
/* .mul_mat_q = */ mmq,
|
/* .mul_mat_q = */ mmq,
|
||||||
|
@ -535,24 +553,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
instances.push_back(instance);
|
instances.push_back(instance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
// this ordering separates the prompt and generation tests
|
|
||||||
for (const auto & n_prompt : params.n_prompt) {
|
|
||||||
if (n_prompt == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto instances_prompt = get_cmd_params_instances_int(params, 0, n_prompt);
|
|
||||||
instances.insert(instances.end(), instances_prompt.begin(), instances_prompt.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto & n_gen : params.n_gen) {
|
|
||||||
if (n_gen == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0);
|
|
||||||
instances.insert(instances.end(), instances_gen.begin(), instances_gen.end());
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return instances;
|
return instances;
|
||||||
}
|
}
|
||||||
|
@ -576,6 +576,7 @@ struct test {
|
||||||
ggml_type type_k;
|
ggml_type type_k;
|
||||||
ggml_type type_v;
|
ggml_type type_v;
|
||||||
int n_gpu_layers;
|
int n_gpu_layers;
|
||||||
|
llama_split_mode split_mode;
|
||||||
int main_gpu;
|
int main_gpu;
|
||||||
bool no_kv_offload;
|
bool no_kv_offload;
|
||||||
bool mul_mat_q;
|
bool mul_mat_q;
|
||||||
|
@ -597,6 +598,7 @@ struct test {
|
||||||
type_k = inst.type_k;
|
type_k = inst.type_k;
|
||||||
type_v = inst.type_v;
|
type_v = inst.type_v;
|
||||||
n_gpu_layers = inst.n_gpu_layers;
|
n_gpu_layers = inst.n_gpu_layers;
|
||||||
|
split_mode = inst.split_mode;
|
||||||
main_gpu = inst.main_gpu;
|
main_gpu = inst.main_gpu;
|
||||||
no_kv_offload = inst.no_kv_offload;
|
no_kv_offload = inst.no_kv_offload;
|
||||||
mul_mat_q = inst.mul_mat_q;
|
mul_mat_q = inst.mul_mat_q;
|
||||||
|
@ -660,7 +662,8 @@ struct test {
|
||||||
"cpu_info", "gpu_info",
|
"cpu_info", "gpu_info",
|
||||||
"model_filename", "model_type", "model_size", "model_n_params",
|
"model_filename", "model_type", "model_size", "model_n_params",
|
||||||
"n_batch", "n_threads", "type_k", "type_v",
|
"n_batch", "n_threads", "type_k", "type_v",
|
||||||
"n_gpu_layers", "main_gpu", "no_kv_offload",
|
"n_gpu_layers", "split_mode",
|
||||||
|
"main_gpu", "no_kv_offload",
|
||||||
"mul_mat_q", "tensor_split",
|
"mul_mat_q", "tensor_split",
|
||||||
"n_prompt", "n_gen", "test_time",
|
"n_prompt", "n_gen", "test_time",
|
||||||
"avg_ns", "stddev_ns",
|
"avg_ns", "stddev_ns",
|
||||||
|
@ -711,7 +714,8 @@ struct test {
|
||||||
cpu_info, gpu_info,
|
cpu_info, gpu_info,
|
||||||
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
|
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
|
||||||
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
||||||
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(no_kv_offload),
|
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
||||||
|
std::to_string(main_gpu), std::to_string(no_kv_offload),
|
||||||
std::to_string(mul_mat_q), tensor_split_str,
|
std::to_string(mul_mat_q), tensor_split_str,
|
||||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||||
|
@ -867,6 +871,9 @@ struct markdown_printer : public printer {
|
||||||
if (field == "n_gpu_layers") {
|
if (field == "n_gpu_layers") {
|
||||||
return "ngl";
|
return "ngl";
|
||||||
}
|
}
|
||||||
|
if (field == "split_mode") {
|
||||||
|
return "sm";
|
||||||
|
}
|
||||||
if (field == "n_threads") {
|
if (field == "n_threads") {
|
||||||
return "threads";
|
return "threads";
|
||||||
}
|
}
|
||||||
|
@ -907,6 +914,9 @@ struct markdown_printer : public printer {
|
||||||
if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
|
if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
|
||||||
fields.push_back("main_gpu");
|
fields.push_back("main_gpu");
|
||||||
}
|
}
|
||||||
|
if (params.split_mode.size() > 1 || params.split_mode != cmd_params_defaults.split_mode) {
|
||||||
|
fields.push_back("split_mode");
|
||||||
|
}
|
||||||
if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) {
|
if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) {
|
||||||
fields.push_back("mul_mat_q");
|
fields.push_back("mul_mat_q");
|
||||||
}
|
}
|
||||||
|
|
136
examples/pydantic-models-to-grammar-examples.py
Normal file
136
examples/pydantic-models-to-grammar-examples.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
# Function calling example using pydantic models.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Union, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from pydantic_models_to_grammar import generate_gbnf_grammar_and_documentation
|
||||||
|
|
||||||
|
# Function to get completion on the llama.cpp server with grammar.
|
||||||
|
def create_completion(prompt, grammar):
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
data = {"prompt": prompt, "grammar": grammar}
|
||||||
|
|
||||||
|
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
print(data["content"])
|
||||||
|
return data["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# A function for the agent to send a message to the user.
|
||||||
|
class SendMessageToUser(BaseModel):
|
||||||
|
"""
|
||||||
|
Send a message to the User.
|
||||||
|
"""
|
||||||
|
chain_of_thought: str = Field(..., description="Your chain of thought while sending the message.")
|
||||||
|
message: str = Field(..., description="Message you want to send to the user.")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
# Enum for the calculator function.
|
||||||
|
class MathOperation(Enum):
|
||||||
|
ADD = "add"
|
||||||
|
SUBTRACT = "subtract"
|
||||||
|
MULTIPLY = "multiply"
|
||||||
|
DIVIDE = "divide"
|
||||||
|
|
||||||
|
|
||||||
|
# Very simple calculator tool for the agent.
|
||||||
|
class Calculator(BaseModel):
|
||||||
|
"""
|
||||||
|
Perform a math operation on two numbers.
|
||||||
|
"""
|
||||||
|
number_one: Union[int, float] = Field(..., description="First number.")
|
||||||
|
operation: MathOperation = Field(..., description="Math operation to perform.")
|
||||||
|
number_two: Union[int, float] = Field(..., description="Second number.")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if self.operation == MathOperation.ADD:
|
||||||
|
return self.number_one + self.number_two
|
||||||
|
elif self.operation == MathOperation.SUBTRACT:
|
||||||
|
return self.number_one - self.number_two
|
||||||
|
elif self.operation == MathOperation.MULTIPLY:
|
||||||
|
return self.number_one * self.number_two
|
||||||
|
elif self.operation == MathOperation.DIVIDE:
|
||||||
|
return self.number_one / self.number_two
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown operation.")
|
||||||
|
|
||||||
|
|
||||||
|
# Here the grammar gets generated by passing the available function models to generate_gbnf_grammar_and_documentation function. This also generates a documentation usable by the LLM.
|
||||||
|
# pydantic_model_list is the list of pydanitc models
|
||||||
|
# outer_object_name is an optional name for an outer object around the actual model object. Like a "function" object with "function_parameters" which contains the actual model object. If None, no outer object will be generated
|
||||||
|
# outer_object_content is the name of outer object content.
|
||||||
|
# model_prefix is the optional prefix for models in the documentation. (Default="Output Model")
|
||||||
|
# fields_prefix is the prefix for the model fields in the documentation. (Default="Output Fields")
|
||||||
|
gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(
|
||||||
|
pydantic_model_list=[SendMessageToUser, Calculator], outer_object_name="function",
|
||||||
|
outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters")
|
||||||
|
|
||||||
|
print(gbnf_grammar)
|
||||||
|
print(documentation)
|
||||||
|
|
||||||
|
system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation
|
||||||
|
|
||||||
|
user_message = "What is 42 * 42?"
|
||||||
|
prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
|
||||||
|
|
||||||
|
text = create_completion(prompt=prompt, grammar=gbnf_grammar)
|
||||||
|
# This should output something like this:
|
||||||
|
# {
|
||||||
|
# "function": "calculator",
|
||||||
|
# "function_parameters": {
|
||||||
|
# "number_one": 42,
|
||||||
|
# "operation": "multiply",
|
||||||
|
# "number_two": 42
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
function_dictionary = json.loads(text)
|
||||||
|
if function_dictionary["function"] == "calculator":
|
||||||
|
function_parameters = {**function_dictionary["function_parameters"]}
|
||||||
|
|
||||||
|
print(Calculator(**function_parameters).run())
|
||||||
|
# This should output: 1764
|
||||||
|
|
||||||
|
|
||||||
|
# A example structured output based on pydantic models. The LLM will create an entry for a Book database out of an unstructured text.
|
||||||
|
class Category(Enum):
|
||||||
|
"""
|
||||||
|
The category of the book.
|
||||||
|
"""
|
||||||
|
Fiction = "Fiction"
|
||||||
|
NonFiction = "Non-Fiction"
|
||||||
|
|
||||||
|
|
||||||
|
class Book(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents an entry about a book.
|
||||||
|
"""
|
||||||
|
title: str = Field(..., description="Title of the book.")
|
||||||
|
author: str = Field(..., description="Author of the book.")
|
||||||
|
published_year: Optional[int] = Field(..., description="Publishing year of the book.")
|
||||||
|
keywords: list[str] = Field(..., description="A list of keywords.")
|
||||||
|
category: Category = Field(..., description="Category of the book.")
|
||||||
|
summary: str = Field(..., description="Summary of the book.")
|
||||||
|
|
||||||
|
|
||||||
|
# We need no additional parameters other than our list of pydantic models.
|
||||||
|
gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation([Book])
|
||||||
|
|
||||||
|
system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation
|
||||||
|
|
||||||
|
text = """The Feynman Lectures on Physics is a physics textbook based on some lectures by Richard Feynman, a Nobel laureate who has sometimes been called "The Great Explainer". The lectures were presented before undergraduate students at the California Institute of Technology (Caltech), during 1961–1963. The book's co-authors are Feynman, Robert B. Leighton, and Matthew Sands."""
|
||||||
|
prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant"
|
||||||
|
|
||||||
|
text = create_completion(prompt=prompt, grammar=gbnf_grammar)
|
||||||
|
|
||||||
|
json_data = json.loads(text)
|
||||||
|
|
||||||
|
print(Book(**json_data))
|
1151
examples/pydantic_models_to_grammar.py
Normal file
1151
examples/pydantic_models_to_grammar.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1350,14 +1350,17 @@ struct llama_server_context
|
||||||
res.result_json["model"] = slot.oaicompat_model;
|
res.result_json["model"] = slot.oaicompat_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
queue_results.push_back(res);
|
||||||
|
condition_results.notify_all();
|
||||||
|
|
||||||
|
// done with results, unlock
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
// parent multitask, if any, needs to be updated
|
// parent multitask, if any, needs to be updated
|
||||||
if (slot.multitask_id != -1)
|
if (slot.multitask_id != -1)
|
||||||
{
|
{
|
||||||
update_multi_task(slot.multitask_id, slot.task_id, res);
|
update_multi_task(slot.multitask_id, slot.task_id, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
queue_results.push_back(res);
|
|
||||||
condition_results.notify_all();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_embedding(llama_client_slot &slot)
|
void send_embedding(llama_client_slot &slot)
|
||||||
|
@ -1603,6 +1606,7 @@ struct llama_server_context
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
|
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
|
||||||
|
std::vector<task_result> agg_results;
|
||||||
auto queue_iterator = queue_multitasks.begin();
|
auto queue_iterator = queue_multitasks.begin();
|
||||||
while (queue_iterator != queue_multitasks.end())
|
while (queue_iterator != queue_multitasks.end())
|
||||||
{
|
{
|
||||||
|
@ -1623,8 +1627,9 @@ struct llama_server_context
|
||||||
}
|
}
|
||||||
aggregate_result.result_json = json{ "results", result_jsons };
|
aggregate_result.result_json = json{ "results", result_jsons };
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_results);
|
|
||||||
queue_results.push_back(aggregate_result);
|
agg_results.push_back(aggregate_result);
|
||||||
|
|
||||||
condition_results.notify_all();
|
condition_results.notify_all();
|
||||||
|
|
||||||
queue_iterator = queue_multitasks.erase(queue_iterator);
|
queue_iterator = queue_multitasks.erase(queue_iterator);
|
||||||
|
@ -1634,6 +1639,13 @@ struct llama_server_context
|
||||||
++queue_iterator;
|
++queue_iterator;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// done with tasks, unlock
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
|
// copy aggregate results of complete multi-tasks to the results queue
|
||||||
|
std::lock_guard<std::mutex> lock_results(mutex_results);
|
||||||
|
queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool update_slots() {
|
bool update_slots() {
|
||||||
|
@ -1835,7 +1847,7 @@ struct llama_server_context
|
||||||
|
|
||||||
slot.cache_tokens = prompt_tokens;
|
slot.cache_tokens = prompt_tokens;
|
||||||
|
|
||||||
if (slot.n_past == slot.num_prompt_tokens)
|
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||||
{
|
{
|
||||||
// we have to evaluate at least 1 token to generate logits.
|
// we have to evaluate at least 1 token to generate logits.
|
||||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
||||||
|
@ -2005,12 +2017,15 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||||
printf(" -ngl N, --n-gpu-layers N\n");
|
printf(" -ngl N, --n-gpu-layers N\n");
|
||||||
printf(" number of layers to store in VRAM\n");
|
printf(" number of layers to store in VRAM\n");
|
||||||
|
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
|
||||||
|
printf(" how to split the model across multiple GPUs, one of:\n");
|
||||||
|
printf(" - none: use one GPU only\n");
|
||||||
|
printf(" - layer (default): split layers and KV across GPUs\n");
|
||||||
|
printf(" - row: split rows across GPUs\n");
|
||||||
printf(" -ts SPLIT --tensor-split SPLIT\n");
|
printf(" -ts SPLIT --tensor-split SPLIT\n");
|
||||||
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
|
||||||
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
|
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
|
||||||
printf(" -nommq, --no-mul-mat-q\n");
|
printf(" or for intermediate results and KV (with split-mode = row)\n");
|
||||||
printf(" use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
|
|
||||||
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
|
||||||
#endif
|
#endif
|
||||||
printf(" -m FNAME, --model FNAME\n");
|
printf(" -m FNAME, --model FNAME\n");
|
||||||
printf(" model path (default: %s)\n", params.model.c_str());
|
printf(" model path (default: %s)\n", params.model.c_str());
|
||||||
|
@ -2253,6 +2268,33 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
"See main README.md for information on enabling GPU BLAS support",
|
"See main README.md for information on enabling GPU BLAS support",
|
||||||
{{"n_gpu_layers", params.n_gpu_layers}});
|
{{"n_gpu_layers", params.n_gpu_layers}});
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
else if (arg == "--split-mode" || arg == "-sm")
|
||||||
|
{
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::string arg_next = argv[i];
|
||||||
|
if (arg_next == "none")
|
||||||
|
{
|
||||||
|
params.split_mode = LLAMA_SPLIT_NONE;
|
||||||
|
}
|
||||||
|
else if (arg_next == "layer")
|
||||||
|
{
|
||||||
|
params.split_mode = LLAMA_SPLIT_LAYER;
|
||||||
|
}
|
||||||
|
else if (arg_next == "row")
|
||||||
|
{
|
||||||
|
params.split_mode = LLAMA_SPLIT_ROW;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#ifndef GGML_USE_CUBLAS
|
||||||
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
}
|
}
|
||||||
else if (arg == "--tensor-split" || arg == "-ts")
|
else if (arg == "--tensor-split" || arg == "-ts")
|
||||||
{
|
{
|
||||||
|
|
34
ggml-alloc.c
34
ggml-alloc.c
|
@ -102,8 +102,6 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AT_PRINTF("block %d\n", best_fit_block);
|
|
||||||
|
|
||||||
if (best_fit_block == -1) {
|
if (best_fit_block == -1) {
|
||||||
// the last block is our last resort
|
// the last block is our last resort
|
||||||
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
||||||
|
@ -117,6 +115,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct free_block * block = &alloc->free_blocks[best_fit_block];
|
struct free_block * block = &alloc->free_blocks[best_fit_block];
|
||||||
void * addr = block->addr;
|
void * addr = block->addr;
|
||||||
block->addr = (char*)block->addr + size;
|
block->addr = (char*)block->addr + size;
|
||||||
|
@ -129,6 +128,8 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AT_PRINTF("block %d, addr %p\n", best_fit_block, addr);
|
||||||
|
|
||||||
tensor->data = addr;
|
tensor->data = addr;
|
||||||
tensor->buffer = alloc->buffer;
|
tensor->buffer = alloc->buffer;
|
||||||
if (!alloc->measure) {
|
if (!alloc->measure) {
|
||||||
|
@ -229,6 +230,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
|
||||||
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
|
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
|
||||||
} else {
|
} else {
|
||||||
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
|
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
|
||||||
|
ggml_backend_buffer_reset(alloc->buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,9 +265,9 @@ ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
|
ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft) {
|
||||||
// create a backend buffer to get the correct tensor allocation sizes
|
// create a backend buffer to get the correct tensor allocation sizes
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
|
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, 1);
|
||||||
|
|
||||||
// TODO: move alloc initialization to a common ggml_tallocr_new_impl function
|
// TODO: move alloc initialization to a common ggml_tallocr_new_impl function
|
||||||
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
|
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
|
||||||
|
@ -275,13 +277,22 @@ ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backe
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
|
ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
|
return ggml_tallocr_new_measure_from_buft(ggml_backend_get_default_buffer_type(backend));
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size) {
|
||||||
|
// create a backend buffer to get the correct tensor allocation sizes
|
||||||
|
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
|
||||||
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
|
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
|
||||||
alloc->buffer_owned = true;
|
alloc->buffer_owned = true;
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
|
||||||
|
return ggml_tallocr_new_from_buft(ggml_backend_get_default_buffer_type(backend), size);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
|
ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
|
||||||
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
|
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
|
||||||
|
|
||||||
|
@ -779,10 +790,21 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
|
||||||
|
|
||||||
if (nbytes == 0) {
|
if (nbytes == 0) {
|
||||||
// all the tensors in the context are already allocated
|
// all the tensors in the context are already allocated
|
||||||
|
#ifndef NDEBUG
|
||||||
|
fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
|
||||||
|
#endif
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
|
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
|
||||||
|
if (buffer == NULL) {
|
||||||
|
// failed to allocate buffer
|
||||||
|
#ifndef NDEBUG
|
||||||
|
fprintf(stderr, "%s: failed to allocate buffer\n", __func__);
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
|
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
|
||||||
|
|
||||||
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
|
|
@ -52,8 +52,10 @@ typedef struct ggml_tallocr * ggml_tallocr_t;
|
||||||
|
|
||||||
GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
|
GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
|
||||||
GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
|
GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
|
||||||
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
|
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size);
|
||||||
GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
|
GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
|
||||||
|
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
|
||||||
|
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft);
|
||||||
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
|
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
|
||||||
|
|
||||||
GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
|
GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
|
||||||
|
|
|
@ -16,9 +16,10 @@ extern "C" {
|
||||||
typedef void * ggml_backend_buffer_type_context_t;
|
typedef void * ggml_backend_buffer_type_context_t;
|
||||||
|
|
||||||
struct ggml_backend_buffer_type_i {
|
struct ggml_backend_buffer_type_i {
|
||||||
|
const char * (*get_name) (ggml_backend_buffer_type_t buft);
|
||||||
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
|
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
|
||||||
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
|
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
|
||||||
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
|
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
|
||||||
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
|
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
|
||||||
// check if tensor data is in host memory
|
// check if tensor data is in host memory
|
||||||
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
|
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
|
||||||
|
@ -34,16 +35,15 @@ extern "C" {
|
||||||
typedef void * ggml_backend_buffer_context_t;
|
typedef void * ggml_backend_buffer_context_t;
|
||||||
|
|
||||||
struct ggml_backend_buffer_i {
|
struct ggml_backend_buffer_i {
|
||||||
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
const char * (*get_name) (ggml_backend_buffer_t buffer);
|
||||||
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
|
void (*free_buffer)(ggml_backend_buffer_t buffer);
|
||||||
void * (*get_base) (ggml_backend_buffer_t buffer);
|
void * (*get_base) (ggml_backend_buffer_t buffer);
|
||||||
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
// (optional) copy tensor between different buffer-type, allow for single-copy tranfers
|
bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
|
||||||
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
|
||||||
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
|
||||||
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_buffer {
|
struct ggml_backend_buffer {
|
||||||
|
@ -51,6 +51,7 @@ extern "C" {
|
||||||
ggml_backend_buffer_type_t buft;
|
ggml_backend_buffer_type_t buft;
|
||||||
ggml_backend_buffer_context_t context;
|
ggml_backend_buffer_context_t context;
|
||||||
size_t size;
|
size_t size;
|
||||||
|
enum ggml_backend_buffer_usage usage;
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_buffer_init(
|
ggml_backend_buffer_t ggml_backend_buffer_init(
|
||||||
|
@ -59,6 +60,8 @@ extern "C" {
|
||||||
ggml_backend_buffer_context_t context,
|
ggml_backend_buffer_context_t context,
|
||||||
size_t size);
|
size_t size);
|
||||||
|
|
||||||
|
// do not use directly, use ggml_backend_tensor_copy instead
|
||||||
|
bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend
|
// Backend
|
||||||
|
@ -74,22 +77,20 @@ extern "C" {
|
||||||
// buffer allocation
|
// buffer allocation
|
||||||
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
|
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
|
||||||
|
|
||||||
// (optional) asynchroneous tensor data access
|
// (optional) asynchronous tensor data access
|
||||||
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
|
bool (*cpy_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
// (optional) asynchroneous tensor copy
|
// (optional) complete all pending operations
|
||||||
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
|
|
||||||
void (*synchronize)(ggml_backend_t backend);
|
void (*synchronize)(ggml_backend_t backend);
|
||||||
|
|
||||||
// compute graph with a plan
|
// compute graph with a plan
|
||||||
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
|
||||||
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
|
|
||||||
// compute graph without a plan
|
// compute graph without a plan (async)
|
||||||
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
// check if the backend supports an operation
|
// check if the backend supports an operation
|
||||||
|
@ -102,7 +103,6 @@ extern "C" {
|
||||||
ggml_backend_context_t context;
|
ggml_backend_context_t context;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend registry
|
// Backend registry
|
||||||
//
|
//
|
||||||
|
|
713
ggml-backend.c
713
ggml-backend.c
File diff suppressed because it is too large
Load diff
|
@ -17,22 +17,31 @@ extern "C" {
|
||||||
//
|
//
|
||||||
|
|
||||||
// buffer type
|
// buffer type
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
|
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
|
||||||
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
|
||||||
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
|
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
||||||
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
|
GGML_API size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
|
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
|
||||||
|
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
|
||||||
|
|
||||||
// buffer
|
// buffer
|
||||||
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
enum ggml_backend_buffer_usage {
|
||||||
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
GGML_BACKEND_BUFFER_USAGE_ANY = 0,
|
||||||
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
|
||||||
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
};
|
||||||
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
|
||||||
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
|
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
||||||
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
|
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
|
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
|
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
|
||||||
|
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend
|
// Backend
|
||||||
|
@ -140,23 +149,24 @@ extern "C" {
|
||||||
typedef struct ggml_backend_sched * ggml_backend_sched_t;
|
typedef struct ggml_backend_sched * ggml_backend_sched_t;
|
||||||
|
|
||||||
// Initialize a backend scheduler
|
// Initialize a backend scheduler
|
||||||
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
|
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
|
||||||
|
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
||||||
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
|
||||||
|
|
||||||
// Initialize backend buffers from a measure graph
|
// Initialize backend buffers from a measure graph
|
||||||
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
|
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
|
||||||
|
// Get the number of splits of the last graph
|
||||||
|
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
|
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
|
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||||
|
|
||||||
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||||
|
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||||
|
|
||||||
// Allocate a graph on the backend scheduler
|
// Allocate and compute graph on the backend scheduler
|
||||||
GGML_API void ggml_backend_sched_graph_compute(
|
GGML_API void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||||
ggml_backend_sched_t sched,
|
|
||||||
struct ggml_cgraph * graph);
|
|
||||||
|
|
||||||
|
// Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
|
||||||
|
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Utils
|
// Utils
|
||||||
|
@ -176,7 +186,7 @@ extern "C" {
|
||||||
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
||||||
|
|
||||||
// Compare the output of two backends
|
// Compare the output of two backends
|
||||||
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
|
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
|
||||||
|
|
||||||
// Tensor initialization
|
// Tensor initialization
|
||||||
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
||||||
|
|
964
ggml-cuda.cu
964
ggml-cuda.cu
File diff suppressed because it is too large
Load diff
26
ggml-cuda.h
26
ggml-cuda.h
|
@ -27,22 +27,6 @@ GGML_API void * ggml_cuda_host_malloc(size_t size);
|
||||||
GGML_API void ggml_cuda_host_free(void * ptr);
|
GGML_API void ggml_cuda_host_free(void * ptr);
|
||||||
|
|
||||||
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
|
|
||||||
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
|
|
||||||
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
|
||||||
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
|
||||||
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
|
|
||||||
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
|
|
||||||
GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
GGML_API void ggml_cuda_set_main_device(int main_device);
|
|
||||||
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
|
|
||||||
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
|
|
||||||
GGML_API void ggml_cuda_free_scratch(void);
|
|
||||||
GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||||
|
|
||||||
GGML_API int ggml_cuda_get_device_count(void);
|
GGML_API int ggml_cuda_get_device_count(void);
|
||||||
|
@ -52,13 +36,17 @@ GGML_API void ggml_cuda_get_device_description(int device, char * description,
|
||||||
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
|
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
|
||||||
|
|
||||||
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
||||||
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
||||||
|
// split tensor buffer that splits matrices by rows across multiple devices
|
||||||
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
|
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
|
||||||
|
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_API int ggml_backend_cuda_get_device_count(void);
|
||||||
|
GGML_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
|
||||||
|
GGML_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -228,6 +228,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
#define GGML_HASHTABLE_FULL ((size_t)-1)
|
#define GGML_HASHTABLE_FULL ((size_t)-1)
|
||||||
#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
|
#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
|
||||||
|
|
||||||
|
struct ggml_hash_set ggml_hash_set_new(size_t size);
|
||||||
|
|
||||||
bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
||||||
|
|
||||||
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
|
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
|
||||||
|
|
55
ggml-metal.m
55
ggml-metal.m
|
@ -2532,10 +2532,10 @@ static void ggml_backend_metal_free_device(void) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||||
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
return "Metal";
|
||||||
|
|
||||||
return ctx->all_data;
|
UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
@ -2553,6 +2553,12 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
||||||
free(ctx);
|
free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
||||||
|
|
||||||
|
return ctx->all_data;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
memcpy((char *)tensor->data + offset, data, size);
|
memcpy((char *)tensor->data + offset, data, size);
|
||||||
|
|
||||||
|
@ -2565,14 +2571,12 @@ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, c
|
||||||
UNUSED(buffer);
|
UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||||
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
|
if (ggml_backend_buffer_is_host(src->buffer)) {
|
||||||
|
memcpy(dst->data, src->data, ggml_nbytes(src));
|
||||||
UNUSED(buffer);
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
|
||||||
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
|
|
||||||
|
|
||||||
UNUSED(buffer);
|
UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
@ -2584,18 +2588,25 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
||||||
|
/* .get_name = */ ggml_backend_metal_buffer_get_name,
|
||||||
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
||||||
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
||||||
/* .init_tensor = */ NULL,
|
/* .init_tensor = */ NULL,
|
||||||
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
||||||
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
||||||
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
|
/* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
|
||||||
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
|
||||||
/* .clear = */ ggml_backend_metal_buffer_clear,
|
/* .clear = */ ggml_backend_metal_buffer_clear,
|
||||||
|
/* .reset = */ NULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
// default buffer type
|
// default buffer type
|
||||||
|
|
||||||
|
static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||||
|
return "Metal";
|
||||||
|
|
||||||
|
UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
||||||
|
|
||||||
|
@ -2668,6 +2679,7 @@ static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t bu
|
||||||
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
||||||
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||||
/* .iface = */ {
|
/* .iface = */ {
|
||||||
|
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
|
||||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
||||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
||||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
||||||
|
@ -2691,6 +2703,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||||
ctx->n_buffers = 0;
|
ctx->n_buffers = 0;
|
||||||
|
|
||||||
const size_t size_page = sysconf(_SC_PAGESIZE);
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
||||||
|
|
||||||
|
// page-align the data ptr
|
||||||
|
{
|
||||||
|
const uintptr_t offs = (uintptr_t) data % size_page;
|
||||||
|
data = (void *) ((char *) data - offs);
|
||||||
|
size += offs;
|
||||||
|
}
|
||||||
|
|
||||||
size_t size_aligned = size;
|
size_t size_aligned = size;
|
||||||
if ((size_aligned % size_page) != 0) {
|
if ((size_aligned % size_page) != 0) {
|
||||||
size_aligned += (size_page - (size_aligned % size_page));
|
size_aligned += (size_page - (size_aligned % size_page));
|
||||||
|
@ -2791,14 +2811,13 @@ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct
|
||||||
return ggml_metal_supports_op(metal_ctx, op);
|
return ggml_metal_supports_op(metal_ctx, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_backend_i metal_backend_i = {
|
static struct ggml_backend_i ggml_backend_metal_i = {
|
||||||
/* .get_name = */ ggml_backend_metal_name,
|
/* .get_name = */ ggml_backend_metal_name,
|
||||||
/* .free = */ ggml_backend_metal_free,
|
/* .free = */ ggml_backend_metal_free,
|
||||||
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
||||||
/* .set_tensor_async = */ NULL,
|
/* .set_tensor_async = */ NULL,
|
||||||
/* .get_tensor_async = */ NULL,
|
/* .get_tensor_async = */ NULL,
|
||||||
/* .cpy_tensor_from_async = */ NULL,
|
/* .cpy_tensor_async = */ NULL,
|
||||||
/* .cpy_tensor_to_async = */ NULL,
|
|
||||||
/* .synchronize = */ NULL,
|
/* .synchronize = */ NULL,
|
||||||
/* .graph_plan_create = */ NULL,
|
/* .graph_plan_create = */ NULL,
|
||||||
/* .graph_plan_free = */ NULL,
|
/* .graph_plan_free = */ NULL,
|
||||||
|
@ -2817,7 +2836,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
||||||
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
||||||
|
|
||||||
*metal_backend = (struct ggml_backend) {
|
*metal_backend = (struct ggml_backend) {
|
||||||
/* .interface = */ metal_backend_i,
|
/* .interface = */ ggml_backend_metal_i,
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2825,7 +2844,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
||||||
return backend->iface.get_name == ggml_backend_metal_name;
|
return backend && backend->iface.get_name == ggml_backend_metal_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||||
|
|
335
ggml-opencl.cpp
335
ggml-opencl.cpp
|
@ -1,5 +1,6 @@
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "ggml-opencl.h"
|
#include "ggml-opencl.h"
|
||||||
|
#include "ggml-backend-impl.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
@ -10,7 +11,7 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define CL_TARGET_OPENCL_VERSION 110
|
#define CL_TARGET_OPENCL_VERSION 120
|
||||||
#include <clblast.h>
|
#include <clblast.h>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
|
@ -929,6 +930,12 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cl_init(void) {
|
void ggml_cl_init(void) {
|
||||||
|
static bool initialized = false;
|
||||||
|
if (initialized) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
initialized = true;
|
||||||
|
|
||||||
cl_int err;
|
cl_int err;
|
||||||
|
|
||||||
struct cl_device;
|
struct cl_device;
|
||||||
|
@ -1483,8 +1490,8 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
} else {
|
} else {
|
||||||
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
|
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||||
}
|
}
|
||||||
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
|
cl_mem d_Y = src1->backend == GGML_BACKEND_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||||
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
|
cl_mem d_D = dst->backend == GGML_BACKEND_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||||
|
|
||||||
size_t x_offset = 0;
|
size_t x_offset = 0;
|
||||||
|
|
||||||
|
@ -1501,7 +1508,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
|
|
||||||
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
|
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
|
||||||
// copy src1 to device
|
// copy src1 to device
|
||||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
|
if (src1->backend == GGML_BACKEND_CPU) {
|
||||||
|
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
|
||||||
|
}
|
||||||
|
|
||||||
CL_CHECK(clFinish(queue));
|
CL_CHECK(clFinish(queue));
|
||||||
|
|
||||||
|
@ -1522,8 +1531,10 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy dst to host
|
// copy dst to host
|
||||||
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
if (dst->backend == GGML_BACKEND_CPU) {
|
||||||
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
|
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
||||||
|
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1532,8 +1543,12 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
if (src0->backend != GGML_BACKEND_GPU) {
|
if (src0->backend != GGML_BACKEND_GPU) {
|
||||||
ggml_cl_pool_free(d_X, x_size);
|
ggml_cl_pool_free(d_X, x_size);
|
||||||
}
|
}
|
||||||
ggml_cl_pool_free(d_Y, y_size);
|
if (src1->backend != GGML_BACKEND_GPU) {
|
||||||
ggml_cl_pool_free(d_D, d_size);
|
ggml_cl_pool_free(d_Y, y_size);
|
||||||
|
}
|
||||||
|
if (dst->backend != GGML_BACKEND_GPU) {
|
||||||
|
ggml_cl_pool_free(d_D, d_size);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
|
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
|
||||||
|
@ -1598,6 +1613,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
|
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME: convert on device
|
||||||
|
|
||||||
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
|
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
|
||||||
// convert src1 to fp16
|
// convert src1 to fp16
|
||||||
// TODO: use multiple threads
|
// TODO: use multiple threads
|
||||||
|
@ -1643,11 +1660,13 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy dst to host, then convert to float
|
// copy dst to host, then convert to float
|
||||||
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
|
if (dst->backend == GGML_BACKEND_CPU) {
|
||||||
|
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
|
||||||
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
||||||
|
ggml_fp16_to_fp32_row(tmp, d, d_ne);
|
||||||
ggml_fp16_to_fp32_row(tmp, d, d_ne);
|
} else {
|
||||||
|
// FIXME: convert dst to fp32 on device
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1801,7 +1820,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
|
||||||
const int64_t ne0 = dst->ne[0];
|
const int64_t ne0 = dst->ne[0];
|
||||||
|
@ -1895,3 +1914,291 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
|
||||||
tensor->extra = dst;
|
tensor->extra = dst;
|
||||||
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml-backend
|
||||||
|
|
||||||
|
// buffer
|
||||||
|
|
||||||
|
struct ggml_backend_opencl_buffer_context {
|
||||||
|
~ggml_backend_opencl_buffer_context() {
|
||||||
|
if (buffer) {
|
||||||
|
clReleaseMemObject(buffer);
|
||||||
|
}
|
||||||
|
for (auto * sub_buffer : sub_buffers) {
|
||||||
|
clReleaseMemObject(sub_buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cl_mem buffer;
|
||||||
|
std::vector<cl_mem> sub_buffers;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
|
||||||
|
|
||||||
|
static const char * ggml_backend_opencl_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||||
|
return "OpenCL";
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
|
delete ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
return cl_ptr_base;
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||||
|
if (tensor->view_src != NULL && tensor->view_offs == 0) {
|
||||||
|
tensor->extra = tensor->view_src->extra;
|
||||||
|
} else {
|
||||||
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
|
cl_buffer_region region = {(size_t)((char *)tensor->data - (char *)cl_ptr_base), ggml_nbytes(tensor)};
|
||||||
|
cl_int err;
|
||||||
|
cl_mem sub_buffer = clCreateSubBuffer(ctx->buffer, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||||
|
CL_CHECK(err);
|
||||||
|
ctx->sub_buffers.push_back(sub_buffer);
|
||||||
|
tensor->extra = sub_buffer;
|
||||||
|
}
|
||||||
|
tensor->backend = GGML_BACKEND_GPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
|
cl_mem tensor_buffer = (cl_mem) tensor->extra;
|
||||||
|
CL_CHECK(clEnqueueWriteBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
|
||||||
|
CL_CHECK(clFinish(queue));
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||||
|
cl_mem tensor_buffer = (cl_mem) tensor->extra;
|
||||||
|
CL_CHECK(clEnqueueReadBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
|
||||||
|
CL_CHECK(clFinish(queue));
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||||
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
|
CL_CHECK(clEnqueueFillBuffer(queue, ctx->buffer, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
|
||||||
|
CL_CHECK(clFinish(queue));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
|
||||||
|
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||||
|
for (auto * sub_buffer : ctx->sub_buffers) {
|
||||||
|
clReleaseMemObject(sub_buffer);
|
||||||
|
}
|
||||||
|
ctx->sub_buffers.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
|
||||||
|
/* .get_name = */ ggml_backend_opencl_buffer_get_name,
|
||||||
|
/* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer,
|
||||||
|
/* .get_base = */ ggml_backend_opencl_buffer_get_base,
|
||||||
|
/* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor,
|
||||||
|
/* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor,
|
||||||
|
/* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor,
|
||||||
|
/* .cpy_tensor = */ NULL,
|
||||||
|
/* .clear = */ ggml_backend_opencl_buffer_clear,
|
||||||
|
/* .reset = */ ggml_backend_opencl_buffer_reset,
|
||||||
|
};
|
||||||
|
|
||||||
|
// buffer type
|
||||||
|
|
||||||
|
static const char * ggml_backend_opencl_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
|
||||||
|
return "OpenCL";
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
|
||||||
|
ggml_cl_init();
|
||||||
|
|
||||||
|
cl_int err;
|
||||||
|
cl_mem mem = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
fprintf(stderr, "%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context{mem, {}};
|
||||||
|
|
||||||
|
return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
|
||||||
|
// FIXME: not thread safe, device may not be initialized yet
|
||||||
|
static cl_uint alignment = -1;
|
||||||
|
if (alignment == (cl_uint)-1) {
|
||||||
|
ggml_cl_init();
|
||||||
|
clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
|
||||||
|
}
|
||||||
|
return alignment;
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buffer_type, ggml_backend_t backend) {
|
||||||
|
//return ggml_backend_is_opencl(backend); // opencl must be used through the cpu backend
|
||||||
|
return ggml_backend_is_cpu(backend);
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
|
||||||
|
/* .get_name = */ ggml_backend_opencl_buffer_type_name,
|
||||||
|
/* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer,
|
||||||
|
/* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment,
|
||||||
|
/* .get_alloc_size = */ NULL,
|
||||||
|
/* .supports_backend = */ ggml_backend_opencl_buffer_type_supports_backend,
|
||||||
|
/* .is_host = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
|
||||||
|
static ggml_backend_buffer_type buffer_type = {
|
||||||
|
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
|
||||||
|
/* .context = */ nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &buffer_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// host buffer type
|
||||||
|
|
||||||
|
static const char * ggml_backend_opencl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||||
|
return "CL_Host";
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * ggml_backend_opencl_host_buffer_name(ggml_backend_buffer_t buffer) {
|
||||||
|
return "CL_Host";
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
ggml_cl_host_free(buffer->context);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_t ggml_backend_opencl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
|
void * ptr = ggml_cl_host_malloc(size);
|
||||||
|
|
||||||
|
if (ptr == nullptr) {
|
||||||
|
// fallback to cpu buffer
|
||||||
|
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
||||||
|
buffer->buft = buft;
|
||||||
|
buffer->iface.get_name = ggml_backend_opencl_host_buffer_name;
|
||||||
|
buffer->iface.free_buffer = ggml_backend_opencl_host_buffer_free_buffer;
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type() {
|
||||||
|
static struct ggml_backend_buffer_type ggml_backend_opencl_buffer_type_host = {
|
||||||
|
/* .iface = */ {
|
||||||
|
/* .get_name = */ ggml_backend_opencl_host_buffer_type_name,
|
||||||
|
/* .alloc_buffer = */ ggml_backend_opencl_host_buffer_type_alloc_buffer,
|
||||||
|
/* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
|
||||||
|
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
|
||||||
|
/* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
|
||||||
|
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||||
|
},
|
||||||
|
/* .context = */ nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &ggml_backend_opencl_buffer_type_host;
|
||||||
|
}
|
||||||
|
|
||||||
|
// backend
|
||||||
|
|
||||||
|
static const char * ggml_backend_opencl_name(ggml_backend_t backend) {
|
||||||
|
return "OpenCL";
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_opencl_free(ggml_backend_t backend) {
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(ggml_backend_t backend) {
|
||||||
|
return ggml_backend_opencl_buffer_type();
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
|
||||||
|
for (int i = 0; i < graph->n_nodes; ++i) {
|
||||||
|
ggml_tensor * node = graph->nodes[i];
|
||||||
|
switch (node->op) {
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
|
||||||
|
break;
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
ggml_cl_mul(node->src[0], node->src[1], node);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_opencl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||||
|
switch (op->op) {
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
return ggml_cl_can_mul_mat(op->src[0], op->src[1], op);
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
// return ggml_can_repeat_rows(op->src[1], op->src[0]);
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_backend_i opencl_backend_i = {
|
||||||
|
/* .get_name = */ ggml_backend_opencl_name,
|
||||||
|
/* .free = */ ggml_backend_opencl_free,
|
||||||
|
/* .get_default_buffer_type = */ ggml_backend_opencl_get_default_buffer_type,
|
||||||
|
/* .set_tensor_async = */ NULL,
|
||||||
|
/* .get_tensor_async = */ NULL,
|
||||||
|
/* .cpy_tensor_from_async = */ NULL,
|
||||||
|
/* .cpy_tensor_to_async = */ NULL,
|
||||||
|
/* .synchronize = */ NULL,
|
||||||
|
/* .graph_plan_create = */ NULL,
|
||||||
|
/* .graph_plan_free = */ NULL,
|
||||||
|
/* .graph_plan_compute = */ NULL,
|
||||||
|
/* .graph_compute = */ ggml_backend_opencl_graph_compute,
|
||||||
|
/* .supports_op = */ ggml_backend_opencl_supports_op,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_backend_t ggml_backend_opencl_init() {
|
||||||
|
ggml_backend_t backend = new ggml_backend {
|
||||||
|
/* .interface = */ opencl_backend_i,
|
||||||
|
/* .context = */ nullptr
|
||||||
|
};
|
||||||
|
|
||||||
|
return backend;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_backend_is_opencl(ggml_backend_t backend) {
|
||||||
|
return backend && backend->iface.get_name == ggml_backend_opencl_name;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -9,17 +10,26 @@ extern "C" {
|
||||||
GGML_API void ggml_cl_init(void);
|
GGML_API void ggml_cl_init(void);
|
||||||
|
|
||||||
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
|
||||||
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
|
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
|
||||||
|
|
||||||
GGML_API void * ggml_cl_host_malloc(size_t size);
|
// GGML_API void * ggml_cl_host_malloc(size_t size);
|
||||||
GGML_API void ggml_cl_host_free(void * ptr);
|
// GGML_API void ggml_cl_host_free(void * ptr);
|
||||||
|
|
||||||
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
|
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
|
||||||
|
|
||||||
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
|
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
|
||||||
|
// GGML_API ggml_backend_t ggml_backend_opencl_init(void);
|
||||||
|
|
||||||
|
// GGML_API bool ggml_backend_is_opencl(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
|
||||||
|
// GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -272,10 +272,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
|
||||||
|
|
||||||
// vaddvq_s16
|
// vaddvq_s16
|
||||||
// vpaddq_s16
|
// vpaddq_s16
|
||||||
|
// vpaddq_s32
|
||||||
// vaddvq_s32
|
// vaddvq_s32
|
||||||
// vaddvq_f32
|
// vaddvq_f32
|
||||||
// vmaxvq_f32
|
// vmaxvq_f32
|
||||||
// vcvtnq_s32_f32
|
// vcvtnq_s32_f32
|
||||||
|
// vzip1_u8
|
||||||
|
// vzip2_u8
|
||||||
|
|
||||||
inline static int32_t vaddvq_s16(int16x8_t v) {
|
inline static int32_t vaddvq_s16(int16x8_t v) {
|
||||||
return
|
return
|
||||||
|
@ -291,6 +294,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
|
||||||
return vcombine_s16(a0, b0);
|
return vcombine_s16(a0, b0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
|
||||||
|
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
|
||||||
|
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
|
||||||
|
return vcombine_s32(a0, b0);
|
||||||
|
}
|
||||||
|
|
||||||
inline static int32_t vaddvq_s32(int32x4_t v) {
|
inline static int32_t vaddvq_s32(int32x4_t v) {
|
||||||
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
||||||
}
|
}
|
||||||
|
@ -316,6 +325,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
||||||
|
uint8x8_t res;
|
||||||
|
|
||||||
|
res[0] = a[0]; res[1] = b[0];
|
||||||
|
res[2] = a[1]; res[3] = b[1];
|
||||||
|
res[4] = a[2]; res[5] = b[2];
|
||||||
|
res[6] = a[3]; res[7] = b[3];
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
||||||
|
uint8x8_t res;
|
||||||
|
|
||||||
|
res[0] = a[4]; res[1] = b[4];
|
||||||
|
res[2] = a[5]; res[3] = b[5];
|
||||||
|
res[4] = a[6]; res[5] = b[6];
|
||||||
|
res[6] = a[7]; res[7] = b[7];
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
// vld1q_s16_x2
|
// vld1q_s16_x2
|
||||||
// vld1q_u8_x2
|
// vld1q_u8_x2
|
||||||
// vld1q_u8_x4
|
// vld1q_u8_x4
|
||||||
|
@ -7554,9 +7585,9 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||||
|
|
||||||
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
||||||
|
|
||||||
int8x16x4_t q2u;
|
ggml_int8x16x4_t q2u;
|
||||||
int8x16x4_t q2s;
|
ggml_int8x16x4_t q2s;
|
||||||
int8x16x4_t q8b;
|
ggml_int8x16x4_t q8b;
|
||||||
|
|
||||||
int32x4x4_t scales32;
|
int32x4x4_t scales32;
|
||||||
|
|
||||||
|
@ -7578,7 +7609,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||||
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
|
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
|
||||||
int32x4_t sumi = vdupq_n_s32(0);
|
int32x4_t sumi = vdupq_n_s32(0);
|
||||||
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
|
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
|
||||||
q8b = vld1q_s8_x4(q8); q8 += 64;
|
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
|
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
|
||||||
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
|
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
|
||||||
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
|
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
|
||||||
|
|
30
ggml.c
30
ggml.c
|
@ -2354,6 +2354,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_free(struct ggml_context * ctx) {
|
void ggml_free(struct ggml_context * ctx) {
|
||||||
|
if (ctx == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// make this function thread safe
|
// make this function thread safe
|
||||||
ggml_critical_section_start();
|
ggml_critical_section_start();
|
||||||
|
|
||||||
|
@ -4362,6 +4366,23 @@ struct ggml_tensor * ggml_cpy(
|
||||||
return ggml_cpy_impl(ctx, a, b);
|
return ggml_cpy_impl(ctx, a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_cast(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_type type) {
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
|
||||||
|
ggml_format_name(result, "%s (copy)", a->name);
|
||||||
|
|
||||||
|
result->op = GGML_OP_CPY;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = a;
|
||||||
|
result->src[1] = result;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_cont
|
// ggml_cont
|
||||||
|
|
||||||
static struct ggml_tensor * ggml_cont_impl(
|
static struct ggml_tensor * ggml_cont_impl(
|
||||||
|
@ -14871,7 +14892,7 @@ size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tenso
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
||||||
size = ggml_hash_size(size);
|
size = ggml_hash_size(size);
|
||||||
struct ggml_hash_set result;
|
struct ggml_hash_set result;
|
||||||
result.size = size;
|
result.size = size;
|
||||||
|
@ -16620,7 +16641,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
return GGML_EXIT_SUCCESS;
|
return GGML_EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
if (n_threads <= 0) {
|
if (n_threads <= 0) {
|
||||||
n_threads = GGML_DEFAULT_N_THREADS;
|
n_threads = GGML_DEFAULT_N_THREADS;
|
||||||
}
|
}
|
||||||
|
@ -16682,14 +16703,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
cur = 0;
|
||||||
const struct ggml_tensor * src0 = node->src[2];
|
const struct ggml_tensor * src0 = node->src[2];
|
||||||
const struct ggml_tensor * src1 = node->src[1];
|
const struct ggml_tensor * src1 = node->src[1];
|
||||||
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
||||||
if (src1->type != vec_dot_type) {
|
if (src1->type != vec_dot_type) {
|
||||||
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||||
}
|
}
|
||||||
const int n_as = ggml_get_op_params_i32(node, 1);
|
const int n_as = ggml_get_op_params_i32(node, 1);
|
||||||
cur = GGML_PAD(cur, sizeof(int64_t)); // align
|
cur += GGML_PAD(cur, sizeof(int64_t)); // align
|
||||||
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||||
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
||||||
} break;
|
} break;
|
||||||
|
|
9
ggml.h
9
ggml.h
|
@ -1165,6 +1165,11 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_cast(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_type type);
|
||||||
|
|
||||||
// make contiguous
|
// make contiguous
|
||||||
GGML_API struct ggml_tensor * ggml_cont(
|
GGML_API struct ggml_tensor * ggml_cont(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
@ -1842,8 +1847,8 @@ extern "C" {
|
||||||
|
|
||||||
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
||||||
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
||||||
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
|
GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
|
||||||
GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
GGML_API int ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
||||||
|
|
||||||
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
||||||
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
||||||
|
|
|
@ -389,6 +389,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.OUTPUT,
|
MODEL_TENSOR.OUTPUT,
|
||||||
MODEL_TENSOR.ATTN_NORM,
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
MODEL_TENSOR.ATTN_QKV,
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
MODEL_TENSOR.ATTN_OUT,
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
MODEL_TENSOR.FFN_NORM,
|
MODEL_TENSOR.FFN_NORM,
|
||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
|
|
@ -191,6 +191,7 @@ class TensorNameMap:
|
||||||
"transformer.h.{bid}.mlp.w1", # qwen
|
"transformer.h.{bid}.mlp.w1", # qwen
|
||||||
"h.{bid}.mlp.c_fc", # gpt2
|
"h.{bid}.mlp.c_fc", # gpt2
|
||||||
"transformer.h.{bid}.mlp.fc1", # phi2
|
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||||
|
"model.layers.{bid}.mlp.fc1", # phi2
|
||||||
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -232,6 +233,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||||
"h.{bid}.mlp.c_proj", # gpt2
|
"h.{bid}.mlp.c_proj", # gpt2
|
||||||
"transformer.h.{bid}.mlp.fc2", # phi2
|
"transformer.h.{bid}.mlp.fc2", # phi2
|
||||||
|
"model.layers.{bid}.mlp.fc2", # phi2
|
||||||
"model.layers.layers.{bid}.mlp.down_proj", # plamo
|
"model.layers.layers.{bid}.mlp.down_proj", # plamo
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
18
llama.h
18
llama.h
|
@ -118,6 +118,12 @@ extern "C" {
|
||||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum llama_split_mode {
|
||||||
|
LLAMA_SPLIT_NONE = 0, // single GPU
|
||||||
|
LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
|
||||||
|
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
|
||||||
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data {
|
typedef struct llama_token_data {
|
||||||
llama_token id; // token id
|
llama_token id; // token id
|
||||||
float logit; // log-odds of the token
|
float logit; // log-odds of the token
|
||||||
|
@ -180,8 +186,16 @@ extern "C" {
|
||||||
|
|
||||||
struct llama_model_params {
|
struct llama_model_params {
|
||||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
|
||||||
|
// main_gpu interpretation depends on split_mode:
|
||||||
|
// LLAMA_SPLIT_NONE: the GPU that is used for the entire model
|
||||||
|
// LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
|
||||||
|
// LLAMA_SPLIT_LAYER: ignored
|
||||||
|
int32_t main_gpu;
|
||||||
|
|
||||||
|
// proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
|
||||||
|
const float * tensor_split;
|
||||||
|
|
||||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||||
// If the provided progress_callback returns true, model loading continues.
|
// If the provided progress_callback returns true, model loading continues.
|
||||||
|
|
|
@ -10,15 +10,15 @@ import sqlite3
|
||||||
try:
|
try:
|
||||||
import git
|
import git
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
print("ERROR: the following Python libraries are required: GitPython, tabulate.")
|
print("ERROR: the following Python libraries are required: GitPython, tabulate.")
|
||||||
sys.exit(1)
|
raise e
|
||||||
|
|
||||||
# Properties by which to differentiate results per commit:
|
# Properties by which to differentiate results per commit:
|
||||||
KEY_PROPERTIES = [
|
KEY_PROPERTIES = [
|
||||||
"cuda", "opencl", "metal", "gpu_blas", "blas", "cpu_info", "gpu_info", "model_filename",
|
"cpu_info", "gpu_info", "n_gpu_layers", "main_gpu", "cuda", "opencl", "metal", "gpu_blas",
|
||||||
"model_type", "model_size", "model_n_params", "n_batch", "n_threads", "type_k", "type_v",
|
"blas", "model_filename", "model_type", "model_size", "model_n_params", "n_batch", "n_threads",
|
||||||
"n_gpu_layers", "main_gpu", "no_kv_offload", "mul_mat_q", "tensor_split", "n_prompt", "n_gen"
|
"type_k", "type_v", "no_kv_offload", "mul_mat_q", "tensor_split", "n_prompt", "n_gen"
|
||||||
]
|
]
|
||||||
|
|
||||||
# Properties that are boolean and are converted to Yes/No for the table:
|
# Properties that are boolean and are converted to Yes/No for the table:
|
||||||
|
@ -37,6 +37,7 @@ PRETTY_NAMES = {
|
||||||
DEFAULT_SHOW = ["model_type"] # Always show these properties by default.
|
DEFAULT_SHOW = ["model_type"] # Always show these properties by default.
|
||||||
DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default.
|
DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default.
|
||||||
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
|
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
|
||||||
|
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
|
||||||
|
|
||||||
DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
|
DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
|
||||||
|
|
||||||
|
@ -308,8 +309,13 @@ else:
|
||||||
if gpu_blas and "gpu_info" not in properties_different:
|
if gpu_blas and "gpu_info" not in properties_different:
|
||||||
show.append("gpu_info")
|
show.append("gpu_info")
|
||||||
|
|
||||||
show += DEFAULT_SHOW
|
|
||||||
show += properties_different
|
show += properties_different
|
||||||
|
|
||||||
|
index_default = 0
|
||||||
|
for prop in ["cpu_info", "gpu_info", "n_gpu_layers", "main_gpu"]:
|
||||||
|
if prop in show:
|
||||||
|
index_default += 1
|
||||||
|
show = show[:index_default] + DEFAULT_SHOW + show[index_default:]
|
||||||
for prop in DEFAULT_HIDE:
|
for prop in DEFAULT_HIDE:
|
||||||
try:
|
try:
|
||||||
show.remove(prop)
|
show.remove(prop)
|
||||||
|
@ -334,6 +340,12 @@ for bool_property in BOOL_PROPERTIES:
|
||||||
for row_table in table:
|
for row_table in table:
|
||||||
row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"
|
row_table[ip] = "Yes" if int(row_table[ip]) == 1 else "No"
|
||||||
|
|
||||||
|
if "model_type" in show:
|
||||||
|
ip = show.index("model_type")
|
||||||
|
for (old, new) in MODEL_SUFFIX_REPLACE.items():
|
||||||
|
for row_table in table:
|
||||||
|
row_table[ip] = row_table[ip].replace(old, new)
|
||||||
|
|
||||||
if "model_size" in show:
|
if "model_size" in show:
|
||||||
ip = show.index("model_size")
|
ip = show.index("model_size")
|
||||||
for row_table in table:
|
for row_table in table:
|
||||||
|
@ -341,10 +353,16 @@ if "model_size" in show:
|
||||||
|
|
||||||
if "gpu_info" in show:
|
if "gpu_info" in show:
|
||||||
ip = show.index("gpu_info")
|
ip = show.index("gpu_info")
|
||||||
for gns in GPU_NAME_STRIP:
|
for row_table in table:
|
||||||
for row_table in table:
|
for gns in GPU_NAME_STRIP:
|
||||||
row_table[ip] = row_table[ip].replace(gns, "")
|
row_table[ip] = row_table[ip].replace(gns, "")
|
||||||
|
|
||||||
|
gpu_names = row_table[ip].split("/")
|
||||||
|
num_gpus = len(gpu_names)
|
||||||
|
all_names_the_same = len(set(gpu_names)) == 1
|
||||||
|
if len(gpu_names) >= 2 and all_names_the_same:
|
||||||
|
row_table[ip] = f"{num_gpus}x {gpu_names[0]}"
|
||||||
|
|
||||||
headers = [PRETTY_NAMES[p] for p in show]
|
headers = [PRETTY_NAMES[p] for p in show]
|
||||||
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
|
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
979cc23b345006504cfc1f67c0fdf627805e3319
|
400c07f00508e6f60fb25405444b5669c365b0a9
|
||||||
|
|
|
@ -376,6 +376,11 @@ struct test_case {
|
||||||
|
|
||||||
// allocate
|
// allocate
|
||||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
|
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
|
||||||
|
if (buf == NULL) {
|
||||||
|
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
|
||||||
|
ggml_free(ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// build graph
|
// build graph
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
|
@ -463,19 +468,23 @@ struct test_case {
|
||||||
GGML_UNUSED(index);
|
GGML_UNUSED(index);
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
|
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
|
||||||
|
|
||||||
if (ud.ok) {
|
if (!cmp_ok) {
|
||||||
printf("\033[1;32mOK\033[0m\n");
|
printf("compare failed ");
|
||||||
} else {
|
|
||||||
printf("\033[1;31mFAIL\033[0m\n");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_free(buf);
|
ggml_backend_buffer_free(buf);
|
||||||
|
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
|
|
||||||
return ud.ok;
|
if (ud.ok && cmp_ok) {
|
||||||
|
printf("\033[1;32mOK\033[0m\n");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\033[1;31mFAIL\033[0m\n");
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool eval_perf(ggml_backend_t backend, const char * op_name) {
|
bool eval_perf(ggml_backend_t backend, const char * op_name) {
|
||||||
|
@ -519,6 +528,11 @@ struct test_case {
|
||||||
|
|
||||||
// allocate
|
// allocate
|
||||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
||||||
|
if (buf == NULL) {
|
||||||
|
printf("failed to allocate tensors\n");
|
||||||
|
ggml_free(ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// randomize tensors
|
// randomize tensors
|
||||||
initialize_tensors(ctx);
|
initialize_tensors(ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue