Merge branch 'master' into feat-override-metadata

KerfuffleV2, 2023-11-18 02:27:56 -07:00
commit 2147421904
32 changed files with 457 additions and 118 deletions

@@ -574,8 +574,12 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GE
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
+    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+        add_compile_options(-mcpu=powerpc64le)
+    else()
     add_compile_options(-mcpu=native -mtune=native)
     #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+    endif()
 else()
     message(STATUS "Unknown architecture")
 endif()

@@ -342,6 +342,12 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
     endif
 endif
+ifneq ($(filter ppc64le%,$(UNAME_M)),)
+    MK_CFLAGS += -mcpu=powerpc64le
+    MK_CXXFLAGS += -mcpu=powerpc64le
+    CUDA_POWER_ARCH = 1
+endif
 else
     MK_CFLAGS += -march=rv64gcv -mabi=lp64d
     MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
@@ -392,6 +398,8 @@ else
 endif #LLAMA_CUDA_NVCC
 ifdef CUDA_DOCKER_ARCH
     NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
+else ifdef CUDA_POWER_ARCH
+    NVCCFLAGS +=
 else
     NVCCFLAGS += -arch=native
 endif # CUDA_DOCKER_ARCH

@@ -1124,6 +1124,12 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
+bool llama_should_add_bos_token(const llama_model * model) {
+    const int add_bos = llama_add_bos_token(model);
+    return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+}
 //
 // YAML utils
 //
@@ -1240,6 +1246,7 @@ void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const cha
     if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
         data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
         data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
+        data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
         data_str = "\"" + data_str + "\"";
         fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
         return;
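The new third regex_replace above escapes any backslash that is not already part of a \n or \" escape produced by the two replacements before it. A minimal standalone sketch of the escaping order (not part of the patch; the input string is made up):

#include <cstdio>
#include <regex>
#include <string>

int main() {
    std::string data_str = "a\\b \"q\"\nend";                                    // lone backslash, a quote, a real newline
    data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");            // newline      -> \n
    data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");           // double quote -> \"
    data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); // stray '\'    -> '\\'  (the added line)
    data_str = "\"" + data_str + "\"";
    std::printf("%s\n", data_str.c_str());  // prints: "a\\b \"q\"\nend"
}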

@@ -202,6 +202,10 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);
+// Uses the value from the model metadata if possible, otherwise
+// defaults to true when model type is SPM, otherwise false.
+bool llama_should_add_bos_token(const llama_model * model);
 //
 // YAML utils
 //
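The call sites changed later in this diff (infill, main, perplexity, llava-cli, server) all switch from a hard-coded SPM check to this helper. A hedged sketch of the intended calling pattern (sketch only; the surrounding model setup is assumed, not shown):

#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Sketch: let the model's GGUF metadata decide whether to prepend BOS, with the old
// "vocab type == SPM" heuristic as the fallback inside llama_should_add_bos_token().
static std::vector<llama_token> tokenize_prompt(const llama_model * model, const std::string & prompt) {
    const bool add_bos = llama_should_add_bos_token(model);
    return ::llama_tokenize(model, prompt, add_bos, /*special =*/ true);
}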

@@ -1136,6 +1136,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
     fprintf(stderr, "\n");
 }
@@ -1355,6 +1356,17 @@ bool consume_common_train_arg(
             return true;
         }
         params->adam_gclip = std::stof(argv[i]);
+    } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+        params->n_gpu_layers = std::stoi(argv[i]);
+#else
+        fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+        fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
     } else if (arg == "-h" || arg == "--help") {
         params->print_usage = true;
         return true;

@@ -6,11 +6,9 @@ from __future__ import annotations
 import argparse
 import json
 import os
-import struct
 import sys
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
-import itertools
 import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor  # type: ignore[import]

@@ -193,7 +193,7 @@ class Model:
             return gguf.MODEL_ARCH.MPT
         if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
             return gguf.MODEL_ARCH.BAICHUAN
-        if arch == "FalconForCausalLM":
+        if arch in ("FalconForCausalLM", "RWForCausalLM"):
             return gguf.MODEL_ARCH.FALCON
         if arch == "GPTBigCodeForCausalLM":
             return gguf.MODEL_ARCH.STARCODER

@@ -2,7 +2,6 @@
 from __future__ import annotations
 import argparse
-import math
 import struct
 import sys
 from enum import IntEnum

@@ -690,6 +690,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
                              data_base_path=pickle_paths[0][:-4],
                              zip_file=zf)
     model = unpickler.load()
+    if 'model' in model: model = model['model']
     as_dict = dict(model.items())
     return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)

@@ -24,6 +24,7 @@ else()
     add_subdirectory(llama-bench)
     add_subdirectory(llava)
     add_subdirectory(main)
+    add_subdirectory(tokenize)
     add_subdirectory(parallel)
     add_subdirectory(perplexity)
     add_subdirectory(quantize)

@@ -3,9 +3,7 @@
 import argparse
 import gguf
-import os
 import struct
-import sys
 import numpy as np
 from pathlib import Path

@@ -548,35 +548,35 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
     struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
     randomize_tensor_normal(lora->tok_embeddings_a, rnd);
-    randomize_tensor_normal(lora->tok_embeddings_b, rnd);
+    ggml_set_zero(lora->tok_embeddings_b);
     randomize_tensor_normal(lora->norm_a, rnd);
-    randomize_tensor_normal(lora->norm_b, rnd);
+    ggml_set_zero(lora->norm_b);
     randomize_tensor_normal(lora->output_a, rnd);
-    randomize_tensor_normal(lora->output_b, rnd);
+    ggml_set_zero(lora->output_b);
     for (uint32_t i = 0; i < n_layer; ++i) {
         auto & layer = lora->layers[i];
         randomize_tensor_normal(layer.attention_norm_a, rnd);
-        randomize_tensor_normal(layer.attention_norm_b, rnd);
+        ggml_set_zero(layer.attention_norm_b);
         randomize_tensor_normal(layer.wq_a, rnd);
-        randomize_tensor_normal(layer.wq_b, rnd);
+        ggml_set_zero(layer.wq_b);
         randomize_tensor_normal(layer.wk_a, rnd);
-        randomize_tensor_normal(layer.wk_b, rnd);
+        ggml_set_zero(layer.wk_b);
         randomize_tensor_normal(layer.wv_a, rnd);
-        randomize_tensor_normal(layer.wv_b, rnd);
+        ggml_set_zero(layer.wv_b);
         randomize_tensor_normal(layer.wo_a, rnd);
-        randomize_tensor_normal(layer.wo_b, rnd);
+        ggml_set_zero(layer.wo_b);
         randomize_tensor_normal(layer.ffn_norm_a, rnd);
-        randomize_tensor_normal(layer.ffn_norm_b, rnd);
+        ggml_set_zero(layer.ffn_norm_b);
         randomize_tensor_normal(layer.w1_a, rnd);
-        randomize_tensor_normal(layer.w1_b, rnd);
+        ggml_set_zero(layer.w1_b);
         randomize_tensor_normal(layer.w2_a, rnd);
-        randomize_tensor_normal(layer.w2_b, rnd);
+        ggml_set_zero(layer.w2_b);
         randomize_tensor_normal(layer.w3_a, rnd);
-        randomize_tensor_normal(layer.w3_b, rnd);
+        ggml_set_zero(layer.w3_b);
     }
     free_random_normal_distribution(rnd);
@@ -1460,17 +1460,6 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
         }
         params->n_rank_w3 = std::stoi(argv[i]);
         params->custom_n_rank_w3 = true;
-    } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
-        if (++i >= argc) {
-            invalid_param = true;
-            break;
-        }
-#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
-        params->common.n_gpu_layers = std::stoi(argv[i]);
-#else
-        fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
-        fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-#endif
     } else {
         fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
         train_print_usage(argc, argv, &default_params);

@@ -230,7 +230,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);
     bool suff_rm_leading_spc = params.escape;

@@ -208,9 +208,10 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
     // llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
-    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, true);
+    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, add_bos);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false);

@@ -127,7 +127,14 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
         fclose(file);
         return false;
     }
-    fread(buffer, 1, fileSize, file); // Read the file into the buffer
+    errno = 0;
+    size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
+    if (ferror(file)) {
+        die_fmt("read error: %s", strerror(errno));
+    }
+    if (ret != (size_t) fileSize) {
+        die("unexpectedly reached end of file");
+    }
     fclose(file); // Close the file
     *bytesOut = buffer;

@@ -229,7 +229,7 @@ int main(int argc, char ** argv) {
         }
     }
-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);
     std::vector<llama_token> embd_inp;

@@ -149,8 +149,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@@ -288,8 +287,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     auto tim1 = std::chrono::high_resolution_clock::now();
@@ -481,7 +479,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     fprintf(stderr, "================================= is_spm = %d\n", is_spm);
     // This is needed as usual for LLaMA models
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count ) {

@@ -501,6 +501,7 @@ struct llama_server_context
     bool multimodal = false;
     bool clean_kv_cache = true;
     bool all_slots_are_idle = false;
+    bool add_bos_token = true;
     int32_t id_gen;
     int32_t n_ctx; // total context for all clients / slots
@@ -573,6 +574,8 @@ struct llama_server_context
         n_ctx = llama_n_ctx(ctx);
+        add_bos_token = llama_should_add_bos_token(model);
         return true;
     }
@@ -864,7 +867,7 @@ struct llama_server_context
     }
     void update_system_prompt() {
-        system_tokens = ::llama_tokenize(ctx, system_prompt, true);
+        system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
         llama_batch_clear(batch);
@@ -1552,7 +1555,7 @@ struct llama_server_context
             }
             else
             {
-                prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+                prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
             }
             slot.num_prompt_tokens = prompt_tokens.size();
@@ -1629,7 +1632,7 @@ struct llama_server_context
             const bool has_images = process_images(slot);
             // process the prefix of first image
-            std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
+            std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
             for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
             {
                 llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);

@@ -0,0 +1,5 @@
+set(TARGET tokenize)
+add_executable(${TARGET} tokenize.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

@@ -0,0 +1,44 @@
+#include "common.h"
+#include "llama.h"
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+int main(int argc, char ** argv) {
+    if (argc < 3 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
+        return 1;
+    }
+    const char * model_path = argv[1];
+    const char * prompt = argv[2];
+    const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
+    llama_backend_init(false);
+    llama_model_params model_params = llama_model_default_params();
+    model_params.vocab_only = true;
+    llama_model * model = llama_load_model_from_file(model_path, model_params);
+    llama_context_params ctx_params = llama_context_default_params();
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+    const bool add_bos = true;
+    std::vector<llama_token> tokens;
+    tokens = ::llama_tokenize(model, prompt, add_bos, true);
+    for (int i = 0; i < (int) tokens.size(); i++) {
+        if (printing_ids) {
+            printf("%d\n", tokens[i]);
+        } else {
+            printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
+        }
+    }
+    return 0;
+}

@@ -88,6 +88,8 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
+#define GGML_CUDA_MAX_NODES 8192
 // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
 // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
 // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -5838,7 +5840,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         return ptr;
     }
 #ifdef DEBUG_CUDA_MALLOC
-    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
+    fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz,
             (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
 #endif
     void * ptr;
@@ -5976,7 +5978,7 @@ void * ggml_cuda_host_malloc(size_t size) {
         // The allocation error can be bypassed. A null ptr will assigned out of this function.
         // This can fixed the OOM error in WSL.
         cudaGetLastError();
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+        fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n",
             size/1024.0/1024.0, cudaGetErrorString(err));
         return nullptr;
     }
@@ -6354,6 +6356,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q8_0:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
             return 1;
         case GGML_TYPE_Q2_K:
             return max_compute_capability >= CC_RDNA2 ? 128 : 32;
@@ -6376,6 +6379,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q8_0:
             return 64;
         case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
             return 1;
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
@@ -7727,7 +7731,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
-void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
@@ -7842,11 +7846,11 @@ static size_t g_temp_tensor_extra_index = 0;
 static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     if (g_temp_tensor_extras == nullptr) {
-        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
     }
     size_t alloc_index = g_temp_tensor_extra_index;
-    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
     ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
     memset(extra, 0, sizeof(*extra));
@@ -8173,11 +8177,11 @@ struct ggml_backend_buffer_context_cuda {
     ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
         if (temp_tensor_extras == nullptr) {
-            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
+            temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
         }
         size_t alloc_index = temp_tensor_extra_index;
-        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
+        temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
         ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
         memset(extra, 0, sizeof(*extra));

@@ -346,9 +346,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     }
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
         GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
     }
@@ -541,11 +541,11 @@ bool ggml_metal_add_buffer(
         ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
         if (ctx->buffers[ctx->n_buffers].metal == nil) {
-            GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+            GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
             return false;
         }
-        GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
         ++ctx->n_buffers;
     } else {
@@ -565,11 +565,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                 return false;
             }
-            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }

@@ -19,7 +19,7 @@
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
-#ifdef __POWER9_VECTOR__
+#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
 #include <altivec.h>
 #undef bool
 #define bool _Bool

ggml.c

@@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
     GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2 == ne12);
     GGML_ASSERT(ne3 == ne13);
+    GGML_ASSERT(ne03 == ne13);
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
     // GGML_ASSERT(nb1 <= nb2);
     // GGML_ASSERT(nb2 <= nb3);
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
     // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
-    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+    // TODO: #if defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    bool use_blas = ggml_is_matrix(src0) &&
+        ggml_is_matrix(src1) &&
+        ggml_is_contiguous(src0) &&
+        (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
+#endif
     if (params->type == GGML_TASK_INIT) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
+        if (use_blas) {
+            return;
+        }
+#endif
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
         return;
     }
@@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (use_blas) {
+        if (params->ith != 0) { // All threads other than the first do no work.
+            return;
+        }
+        // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+        // src0: (k,n)
+        // src1: (k,m)
+        // dst: (m,n)
+        //
+        // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+        // Also expressed as (major,minor)
+        // a: (m,k): so src1 transposed
+        // b: (k,n): so src0
+        // c: (m,n)
+        //
+        // However, if ggml_is_transposed(src1) is true, then
+        // src1->data already contains a transposed version, so sgemm mustn't
+        // transpose it further.
+        int n = src0->ne[0];
+        int k = src0->ne[1];
+        int m = src1->ne[0];
+        int transposeA, lda;
+        if (!ggml_is_transposed(src1)) {
+            transposeA = CblasTrans;
+            lda = m;
+        } else {
+            transposeA = CblasNoTrans;
+            lda = k;
+        }
+        float * a = (float *) ((char *) src1->data);
+        float * b = (float *) ((char *) src0->data);
+        float * c = (float *) ((char *) dst->data);
+        cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+        return;
+    }
+#endif
     // dst[:,:,:,:] = 0
     // for i2,i3:
     // for i1:
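For reference, the shape bookkeeping in the comments above can be restated in one line; this only summarizes what the added sgemm call assumes, in the same (major,minor) convention used there:

    dst (m,n)  =  src1^T (m,k)  x  src0 (k,n)

with a = src1 (transposed by sgemm unless it is already stored transposed), b = src0, c = dst, alpha = 1 and beta = 0; because beta is 0, the GGML_TASK_INIT branch added above can skip zeroing dst when use_blas is true.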
@@ -18073,7 +18126,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     {
         ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv));
-        for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
+        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
             struct gguf_kv * kv = &ctx->kv[i];
             //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -18120,7 +18173,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                 case GGUF_TYPE_STRING:
                     {
                         kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str));
-                        for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
+                        for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
                             ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
                         }
                     } break;
@@ -18148,7 +18201,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     {
         ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
-        for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct gguf_tensor_info * info = &ctx->infos[i];
             for (int j = 0; j < GGML_MAX_DIMS; ++j) {
@@ -18195,7 +18248,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     // compute the total size of the data section, taking into account the alignment
     {
         ctx->size = 0;
-        for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct gguf_tensor_info * info = &ctx->infos[i];
             const int64_t ne =
@@ -18264,7 +18317,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     ggml_set_no_alloc(ctx_data, true);
     // create the tensors
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+    for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
         const int64_t ne[GGML_MAX_DIMS] = {
             ctx->infos[i].ne[0],
             ctx->infos[i].ne[1],
@@ -18399,24 +18452,29 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
 }
 const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].key.data;
 }
 enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].type;
 }
 enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.type;
 }
 const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.data;
 }
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     struct gguf_kv * kv = &ctx->kv[key_id];
     struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
@@ -18424,70 +18482,90 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
 }
 int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.n;
 }
 uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
     return ctx->kv[key_id].value.uint8;
 }
 int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
     return ctx->kv[key_id].value.int8;
 }
 uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
     return ctx->kv[key_id].value.uint16;
 }
 int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
     return ctx->kv[key_id].value.int16;
 }
 uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
     return ctx->kv[key_id].value.uint32;
 }
 int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
     return ctx->kv[key_id].value.int32;
 }
 float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
     return ctx->kv[key_id].value.float32;
 }
 uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
     return ctx->kv[key_id].value.uint64;
 }
 int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
     return ctx->kv[key_id].value.int64;
 }
 double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
     return ctx->kv[key_id].value.float64;
 }
 bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
     return ctx->kv[key_id].value.bool_;
 }
 const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
     return ctx->kv[key_id].value.str.data;
 }
+const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
+    return &ctx->kv[key_id].value;
+}
 int gguf_get_n_tensors(const struct gguf_context * ctx) {
     return ctx->header.n_tensors;
 }
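gguf_get_val_data() returns a pointer to the raw value union, so callers are expected to check the KV type first and cast accordingly; gguf_kv_to_str() added to llama.cpp later in this diff does exactly that. A minimal hedged sketch (the key name is hypothetical, not from the patch):

#include <cstdint>

#include "ggml.h"

// Sketch only: read a scalar u32 KV through the new accessor.
static bool read_u32_kv(const struct gguf_context * ctx, uint32_t * out) {
    const int key_id = gguf_find_key(ctx, "example.some_u32_key");  // hypothetical key
    if (key_id < 0 || gguf_get_kv_type(ctx, key_id) != GGUF_TYPE_UINT32) {
        return false;
    }
    *out = *(const uint32_t *) gguf_get_val_data(ctx, key_id);
    return true;
}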

ggml.h

@@ -2045,6 +2045,7 @@ extern "C" {
     GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
     GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
     GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);

@@ -117,17 +117,18 @@ class SpecialVocab:
     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
         tokenizer_file = path / 'tokenizer.json'
-        if not tokenizer_file.is_file():
-            return False
+        if tokenizer_file.is_file():
             with open(tokenizer_file, encoding = 'utf-8') as f:
                 tokenizer = json.load(f)
             if self.load_merges:
                 merges = tokenizer.get('model', {}).get('merges')
                 if isinstance(merges, list) and merges and isinstance(merges[0], str):
                     self.merges = merges
+            added_tokens = tokenizer.get('added_tokens', {})
+        else:
+            added_tokens = {}
         tokenizer_config_file = path / 'tokenizer_config.json'
-        added_tokens = tokenizer.get('added_tokens')
-        if added_tokens is None or not tokenizer_config_file.is_file():
+        if not tokenizer_config_file.is_file():
             return True
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
@@ -135,6 +136,10 @@ class SpecialVocab:
             add_entry = tokenizer_config.get(f'add_{typ}_token')
             if isinstance(add_entry, bool):
                 self.add_special_token[typ] = add_entry
+            if not added_tokens:
+                # We will need this to get the content for the token, so if it's empty
+                # may as well just give up.
+                continue
             entry = tokenizer_config.get(f'{typ}_token')
             if isinstance(entry, str):
                 tc_content = entry

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.5.2"
+version = "0.5.3"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [

@@ -86,6 +86,7 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
             curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
         else:
             curr["value"] = field.parts[-1].tolist()[0]
+    if not args.no_tensors:
         for idx, tensor in enumerate(reader.tensors):
             tensors[tensor.name] = {
                 "index": idx,

llama.cpp

@@ -92,7 +92,7 @@
 #define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif
-#define LLAMA_MAX_NODES 4096
+#define LLAMA_MAX_NODES 8192
 //
 // logging
@@ -256,6 +256,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_UNK_ID,
     LLM_KV_TOKENIZER_SEP_ID,
     LLM_KV_TOKENIZER_PAD_ID,
+    LLM_KV_TOKENIZER_ADD_BOS,
+    LLM_KV_TOKENIZER_ADD_EOS,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
 };
@@ -304,6 +306,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
     { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
     { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
+    { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
+    { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
     { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
 };
@@ -586,6 +590,60 @@ static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
     return LLAMA_ROPE_SCALING_UNSPECIFIED;
 }
+static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
+    switch (type) {
+        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t *)data)[i]);
+        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t *)data)[i]);
+        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
+        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t *)data)[i]);
+        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
+        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t *)data)[i]);
+        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
+        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t *)data)[i]);
+        case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
+        case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
+        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
+        default:                return format("unknown type %d", type);
+    }
+}
+static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
+    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
+    switch (type) {
+        case GGUF_TYPE_STRING:
+            return gguf_get_val_str(ctx_gguf, i);
+        case GGUF_TYPE_ARRAY:
+            {
+                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
+                int arr_n = gguf_get_arr_n(ctx_gguf, i);
+                const void * data = gguf_get_arr_data(ctx_gguf, i);
+                std::stringstream ss;
+                ss << "[";
+                for (int j = 0; j < arr_n; j++) {
+                    if (arr_type == GGUF_TYPE_STRING) {
+                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
+                        // escape quotes
+                        replace_all(val, "\\", "\\\\");
+                        replace_all(val, "\"", "\\\"");
+                        ss << '"' << val << '"';
+                    } else if (arr_type == GGUF_TYPE_ARRAY) {
+                        ss << "???";
+                    } else {
+                        ss << gguf_data_to_str(arr_type, data, j);
+                    }
+                    if (j < arr_n - 1) {
+                        ss << ", ";
+                    }
+                }
+                ss << "]";
+                return ss.str();
+            }
+        default:
+            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
+    }
+}
 //
 // ggml helpers
 //
@@ -1069,9 +1127,9 @@ enum e_model {
     MODEL_70B,
 };
-static const size_t kB = 1024;
-static const size_t MB = 1024*kB;
-static const size_t GB = 1024*MB;
+static const size_t kiB = 1024;
+static const size_t MiB = 1024*kiB;
+static const size_t GiB = 1024*MiB;
 struct llama_hparams {
     bool vocab_only;
@@ -1262,6 +1320,9 @@ struct llama_vocab {
     id special_sep_id = -1;
     id special_pad_id = -1;
+    int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
+    int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
     id linefeed_id = 13;
     id special_prefix_id = 32007;
     id special_middle_id = 32009;
@@ -1306,6 +1367,9 @@ struct llama_model {
     int n_gpu_layers;
+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
     // context
     struct ggml_context * ctx = NULL;
@@ -1467,7 +1531,7 @@ static bool llama_kv_cache_init(
             vram_kv_cache += ggml_nbytes(cache.k);
         }
         if (vram_kv_cache > 0) {
-            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
         }
     }
 #endif
@@ -1962,8 +2026,18 @@ struct llama_model_loader {
         for (int i = 0; i < n_kv; i++) {
             const char * name = gguf_get_key(ctx_gguf, i);
             const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx_gguf, i)), gguf_get_arr_n(ctx_gguf, i))
+                : gguf_type_name(type);
-            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-8s\n", __func__, i, name, gguf_type_name(type));
+            std::string value = gguf_kv_to_str(ctx_gguf, i);
+            const size_t MAX_VALUE_LEN = 40;
+            if (value.size() > MAX_VALUE_LEN) {
+                value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
+            }
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
         }
         // print type counts
@@ -2294,6 +2368,17 @@ static void llm_load_hparams(
         llama_model & model) {
     auto & hparams = model.hparams;
+    // get metadata as string
+    for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
+        enum gguf_type type = gguf_get_kv_type(ctx, i);
+        if (type == GGUF_TYPE_ARRAY) {
+            continue;
+        }
+        const char * name = gguf_get_key(ctx, i);
+        const std::string value = gguf_kv_to_str(ctx, i);
+        model.gguf_kv.emplace(name, value);
+    }
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
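The map filled above backs the new llama_model_meta_val_str() accessor whose definition is cut off at the very end of this diff. A hedged usage sketch, assuming the matching declaration lands in llama.h as part of this change and that the function returns a negative value when the key is missing (the key below is only an example):

#include <cstdio>

#include "llama.h"

// Sketch only: query an arbitrary GGUF key-value from a loaded model.
static void print_arch(const struct llama_model * model) {
    char buf[128];
    if (llama_model_meta_val_str(model, "general.architecture", buf, sizeof(buf)) >= 0) {
        std::printf("general.architecture = %s\n", buf);
    }
}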
@@ -2589,6 +2674,23 @@ static void llm_load_vocab(
             } else {
                 id = new_id;
             }
+        }
+        // Handle add_bos_token and add_eos_token
+        std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
+        int kid = gguf_find_key(ctx, key.c_str());
+        enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+        vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
+        }
+        key = kv(LLM_KV_TOKENIZER_ADD_EOS);
+        kid = gguf_find_key(ctx, key.c_str());
+        ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+        vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
+        }
     }
 }
@@ -2720,7 +2822,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
     LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
-    if (ml.n_bytes < GB) {
+    if (ml.n_bytes < GiB) {
         LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
     } else {
         LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
@@ -2759,7 +2861,7 @@ static void llm_load_tensors(
     ml.calc_sizes(ctx_size, mmapped_size);
-    LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);
+    LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, ctx_size/1024.0/1024.0);
     // create the ggml context
     {
@@ -3408,7 +3510,7 @@ static void llm_load_tensors(
             ctx_size +
             mmapped_size - vram_weights; // weights in VRAM not in memory
-        LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
+        LLAMA_LOG_INFO("%s: mem required = %7.2f MiB\n", __func__, mem_required / 1024.0 / 1024.0);
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -3427,7 +3529,7 @@ static void llm_load_tensors(
 #endif // GGML_USE_CUBLAS
         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0);
+        LLAMA_LOG_INFO("%s: VRAM used: %.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0);
 #else
     (void) n_gpu_layers;
 #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -6484,7 +6586,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
                 // and passing 'add space prefix' as bool argument
                 //
-                auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);
+                auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+                if (&fragment == &fragment_buffer.front()) {
+                    raw_text = " " + raw_text; // prefix with space if the first token is not special
+                }
 #ifdef PRETOKENIZERDEBUG
                 fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
@ -8136,7 +8241,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
workers.clear(); workers.clear();
} }
LLAMA_LOG_INFO("size = %8.2f MB -> %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
int64_t tot_count = 0; int64_t tot_count = 0;
for (size_t i = 0; i < hist_cur.size(); i++) { for (size_t i = 0; i < hist_cur.size(); i++) {
hist_all[i] += hist_cur[i]; hist_all[i] += hist_cur[i];
@@ -8677,7 +8782,7 @@ struct llama_context * llama_new_context_with_model(
{ {
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
} }
// resized during inference // resized during inference
@@ -8722,7 +8827,7 @@ struct llama_context * llama_new_context_with_model(
// measure memory requirements for the graph // measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MiB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
// recreate allocator with exact memory requirements // recreate allocator with exact memory requirements
ggml_allocr_free(ctx->alloc); ggml_allocr_free(ctx->alloc);
@@ -8736,7 +8841,7 @@ struct llama_context * llama_new_context_with_model(
#endif #endif
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
ggml_cuda_set_scratch_size(alloc_size); ggml_cuda_set_scratch_size(alloc_size);
LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0);
// calculate total VRAM usage // calculate total VRAM usage
auto add_tensor = [](const ggml_tensor * t, size_t & size) { auto add_tensor = [](const ggml_tensor * t, size_t & size) {
@@ -8756,7 +8861,7 @@ struct llama_context * llama_new_context_with_model(
size_t ctx_vram_size = alloc_size + kv_vram_size; size_t ctx_vram_size = alloc_size + kv_vram_size;
size_t total_vram_size = model_vram_size + ctx_vram_size; size_t total_vram_size = model_vram_size + ctx_vram_size;
LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, LLAMA_LOG_INFO("%s: total VRAM used: %.2f MiB (model: %.2f MiB, context: %.2f MiB)\n", __func__,
total_vram_size / 1024.0 / 1024.0, total_vram_size / 1024.0 / 1024.0,
model_vram_size / 1024.0 / 1024.0, model_vram_size / 1024.0 / 1024.0,
ctx_vram_size / 1024.0 / 1024.0); ctx_vram_size / 1024.0 / 1024.0);
@@ -8780,7 +8885,7 @@ struct llama_context * llama_new_context_with_model(
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); LLAMA_LOG_INFO("%s: max tensor size = %8.2f MiB\n", __func__, max_size/1024.0/1024.0);
#define LLAMA_METAL_CHECK_BUF(result) \ #define LLAMA_METAL_CHECK_BUF(result) \
if (!(result)) { \ if (!(result)) { \
@@ -8846,6 +8951,45 @@ float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train; return model->hparams.rope_freq_scale_train;
} }
int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
int llama_model_meta_count(const struct llama_model * model) {
return (int)model->gguf_kv.size();
}
int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = model->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->first.c_str());
}
int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = model->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s", return snprintf(buf, buf_size, "%s %s %s",
llama_model_arch_name(model->arch).c_str(), llama_model_arch_name(model->arch).c_str(),
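As a usage note for the new metadata accessors added in the hunk above: the sketch below is not part of the change, it only illustrates how a caller might enumerate every GGUF key/value pair. It assumes `model` was obtained from `llama_load_model_from_file()` and that fixed-size buffers are acceptable; the index-based calls follow snprintf semantics, so a return value greater than or equal to the buffer size indicates truncation.

#include <cstdio>
#include "llama.h"

// Sketch only: print every GGUF metadata key/value pair of a loaded model.
static void print_model_metadata(const struct llama_model * model) {
    const int n_kv = llama_model_meta_count(model);
    for (int i = 0; i < n_kv; i++) {
        char key[256];
        char val[256];
        // Defensive: both functions return -1 (and clear the buffer) on failure.
        if (llama_model_meta_key_by_index(model, i, key, sizeof(key)) < 0) continue;
        if (llama_model_meta_val_str_by_index(model, i, val, sizeof(val)) < 0) continue;
        printf("%s = %s\n", key, val);
    }
}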
@@ -9487,6 +9631,14 @@ llama_token llama_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id; return model->vocab.linefeed_id;
} }
int llama_add_bos_token(const struct llama_model * model) {
return model->vocab.special_add_bos;
}
int llama_add_eos_token(const struct llama_model * model) {
return model->vocab.special_add_eos;
}
llama_token llama_token_prefix(const struct llama_model * model) { llama_token llama_token_prefix(const struct llama_model * model) {
return model->vocab.special_prefix_id; return model->vocab.special_prefix_id;
} }

llama.h
View file

@@ -318,6 +318,23 @@ extern "C" {
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
// Get the number of metadata key/value pairs
LLAMA_API int llama_model_meta_count(const struct llama_model * model);
// Get metadata key name by index
LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get metadata value as a string by index
LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get a string describing the model type // Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
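The comment block above also covers lookup by key name. A minimal sketch, assuming the model carries the standard `general.name` key (an absent key simply yields -1 and an empty, null-terminated buffer):

#include <cstdio>
#include "llama.h"

// Sketch only: read a single metadata value by key; "general.name" is just an example key.
static void print_model_name(const struct llama_model * model) {
    char name[128];
    if (llama_model_meta_val_str(model, "general.name", name, sizeof(name)) >= 0) {
        printf("model name: %s\n", name);
    } else {
        printf("model name not present in metadata\n");
    }
}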
@@ -534,6 +551,12 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_bos_token(const struct llama_model * model);
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_eos_token(const struct llama_model * model);
// codellama infill tokens // codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
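One way a caller might consume the new tri-state BOS flag above is sketched here; `default_add_bos` is a hypothetical, caller-chosen fallback used when the metadata does not specify the flag, not something defined by this change:

#include "llama.h"

// Sketch only: decide whether to prepend BOS when tokenizing.
// llama_add_bos_token() returns -1 (unknown), 0 (false) or 1 (true).
static bool want_bos(const struct llama_model * model, bool default_add_bos) {
    const int add_bos = llama_add_bos_token(model);
    return add_bos < 0 ? default_add_bos : add_bos != 0;
}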

View file

@@ -1,7 +1,5 @@
# tests with BPE tokenizer # tests with BPE tokenizer
import os
import sys
import argparse import argparse
from transformers import AutoTokenizer from transformers import AutoTokenizer

View file

@@ -1,7 +1,5 @@
# tests with SPM tokenizer # tests with SPM tokenizer
import os
import sys
import argparse import argparse
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor