Merge branch 'master' into concedo_experimental
# Conflicts: # Makefile # README.md # ggml.h # llama.cpp
This commit is contained in:
commit
cde3760e52
13 changed files with 527 additions and 225 deletions
|
@ -120,6 +120,7 @@ if (LLAMA_ALL_WARNINGS)
|
|||
-Wshadow
|
||||
-Wstrict-prototypes
|
||||
-Wpointer-arith
|
||||
-Wmissing-prototypes
|
||||
)
|
||||
set(cxx_flags
|
||||
-Wall
|
||||
|
|
96
convert.py
Executable file → Normal file
96
convert.py
Executable file → Normal file
|
@ -133,7 +133,7 @@ TENSORS_SET = set(TENSORS_LIST)
|
|||
|
||||
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||
# hardcoded magic range
|
||||
for n_mult in range(256, 1, -1):
|
||||
for n_mult in range(8192, 1, -1):
|
||||
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
|
||||
if calc_ff == n_ff:
|
||||
return n_mult
|
||||
|
@ -141,11 +141,12 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
|
|||
|
||||
@dataclass
|
||||
class Params:
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_mult: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_mult: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_kv_head: Optional[int] # This parameter is only used for Llama 2
|
||||
|
||||
@staticmethod
|
||||
def guessed(model: 'LazyModel') -> 'Params':
|
||||
|
@ -167,11 +168,12 @@ class Params:
|
|||
n_head=n_embd // 128 # guessed
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = 256,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = 256,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -183,15 +185,17 @@ class Params:
|
|||
n_head = config["num_attention_heads"];
|
||||
n_layer = config["num_hidden_layers"];
|
||||
n_ff = config["intermediate_size"];
|
||||
n_kv_head = config.get("num_key_value_heads")
|
||||
|
||||
n_mult = find_n_mult(n_ff, n_embd);
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = n_kv_head,
|
||||
)
|
||||
|
||||
# LLaMA v2 70B params.json
|
||||
|
@ -200,21 +204,22 @@ class Params:
|
|||
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
||||
config = json.load(open(config_path))
|
||||
|
||||
n_vocab = config["vocab_size"];
|
||||
n_embd = config["dim"];
|
||||
n_head = config["n_heads"];
|
||||
n_layer = config["n_layers"];
|
||||
n_mult = config["multiple_of"];
|
||||
n_vocab = config["vocab_size"];
|
||||
n_embd = config["dim"];
|
||||
n_head = config["n_heads"];
|
||||
n_layer = config["n_layers"];
|
||||
n_mult = config["multiple_of"];
|
||||
|
||||
if n_vocab == -1:
|
||||
n_vocab = model["tok_embeddings.weight"].shape[0]
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_mult = n_mult,
|
||||
n_head = n_head,
|
||||
n_layer = n_layer,
|
||||
n_kv_head = None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -317,10 +322,12 @@ class GGMLVocab:
|
|||
Vocab = Union[SentencePieceVocab, GGMLVocab]
|
||||
|
||||
|
||||
def permute(weights: NDArray, n_head: int) -> NDArray:
|
||||
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
n_head //= n_kv_head
|
||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
|
||||
def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
|
||||
|
@ -368,7 +375,7 @@ class Tensor(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def astype(self, data_type: DataType) -> 'Tensor': ...
|
||||
@abstractmethod
|
||||
def permute(self, n_head: int) -> 'Tensor': ...
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
|
||||
@abstractmethod
|
||||
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
|
||||
@abstractmethod
|
||||
|
@ -406,8 +413,8 @@ class UnquantizedTensor(Tensor):
|
|||
r = self.ndarray.shape[0] // 3
|
||||
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
||||
|
||||
def permute(self, n_head: int) -> 'UnquantizedTensor':
|
||||
return UnquantizedTensor(permute(self.ndarray, n_head))
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
|
||||
return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
|
||||
|
||||
|
||||
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
||||
|
@ -455,26 +462,27 @@ class GGMLQuantizedTensor(Tensor):
|
|||
def to_ggml(self) -> 'GGMLQuantizedTensor':
|
||||
return self
|
||||
|
||||
def permute(self, n_head: int) -> 'GGMLQuantizedTensor':
|
||||
return GGMLQuantizedTensor(permute(self.ndarray, n_head), self.shape, self.data_type)
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor':
|
||||
return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head), self.shape, self.data_type)
|
||||
|
||||
|
||||
GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor]
|
||||
|
||||
|
||||
class DeferredPermutedTensor(Tensor):
|
||||
def __init__(self, base: Tensor, n_head: int) -> None:
|
||||
def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
|
||||
self.base = base
|
||||
self.n_head = n_head
|
||||
self.n_kv_head = n_kv_head
|
||||
self.data_type = self.base.data_type
|
||||
|
||||
def astype(self, data_type: DataType) -> Tensor:
|
||||
return self.base.astype(data_type).permute(self.n_head)
|
||||
return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
|
||||
|
||||
def to_ggml(self) -> GGMLCompatibleTensor:
|
||||
return self.base.to_ggml().permute(self.n_head)
|
||||
return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
|
||||
|
||||
def permute(self, n_head: int) -> Tensor:
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
|
||||
raise Exception("shouldn't permute twice")
|
||||
|
||||
|
||||
|
@ -566,8 +574,8 @@ class GPTQForLLaMaQuantizedTensor(Tensor):
|
|||
ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
|
||||
return ret
|
||||
|
||||
def permute(self, n_head: int) -> Tensor:
|
||||
return DeferredPermutedTensor(self, n_head)
|
||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
|
||||
return DeferredPermutedTensor(self, n_head, n_kv_head)
|
||||
|
||||
def to_ggml(self) -> GGMLQuantizedTensor:
|
||||
# The output format looks like this:
|
||||
|
@ -698,10 +706,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
|
|||
return ModelPlus(model, paths, format, vocab)
|
||||
|
||||
|
||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
|
||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
return lazy_tensor.load().permute(n_head)
|
||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
|
||||
return lazy_tensor.load().permute(n_head, n_kv_head)
|
||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
|
||||
|
||||
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
|
@ -726,7 +734,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
|
|||
for i in itertools.count():
|
||||
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
||||
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
|
||||
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
|
||||
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
|
||||
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
||||
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
||||
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
|
||||
|
|
|
@ -402,8 +402,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.antiprompt.push_back(argv[i]);
|
||||
} else if (arg == "--perplexity") {
|
||||
params.perplexity = true;
|
||||
} else if (arg == "--perplexity-lines") {
|
||||
params.perplexity_lines = true;
|
||||
} else if (arg == "--hellaswag") {
|
||||
params.hellaswag = true;
|
||||
} else if (arg == "--hellaswag-tasks") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.hellaswag_tasks = std::stoi(argv[i]);
|
||||
} else if (arg == "--ignore-eos") {
|
||||
params.logit_bias[llama_token_eos()] = -INFINITY;
|
||||
} else if (arg == "--no-penalize-nl") {
|
||||
|
@ -559,8 +565,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp);
|
||||
fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n");
|
||||
fprintf(stdout, " --perplexity-lines compute perplexity over each line of the prompt\n");
|
||||
fprintf(stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
|
||||
fprintf(stdout, " --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %d)\n", params.hellaswag_tasks);
|
||||
fprintf(stdout, " --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||
if (llama_mlock_supported()) {
|
||||
fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
||||
|
|
|
@ -70,7 +70,10 @@ struct gpt_params {
|
|||
std::string lora_adapter = ""; // lora adapter path
|
||||
std::string lora_base = ""; // base model path for the lora adapter
|
||||
|
||||
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
|
||||
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
|
||||
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
||||
|
||||
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
|
||||
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
@ -86,7 +89,6 @@ struct gpt_params {
|
|||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||
bool perplexity = false; // compute perplexity over the prompt
|
||||
bool perplexity_lines = false; // compute perplexity over each line of the prompt
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool mem_test = false; // compute maximum memory usage
|
||||
|
|
|
@ -202,9 +202,9 @@ Example usage: `--top-p 0.95`
|
|||
|
||||
- `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled).
|
||||
|
||||
Tail free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. The method adjusts the logits (token probabilities) by raising them to the power of the parameter z. A higher value of z (e.g., 2.0) will further suppress less likely tokens from the tail of the distribution, while a value of 1.0 disables the effect of TFS. By setting the parameter z, you can control how much the probabilities of less likely tokens are reduced.
|
||||
Tail free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. Similar to Top-P it tries to determine the bulk of the most likely tokens dynamically. But TFS filters out logits based on the second derivative of their probabilities. Adding tokens is stopped after the sum of the second derivatives reaches the parameter z. In short: TFS looks how quickly the probabilities of the tokens decrease and cuts off the tail of unlikely tokens using the parameter z. Typical values for z are in the range of 0.9 to 0.95. A value of 1.0 would include all tokens, and thus disables the effect of TFS.
|
||||
|
||||
Example usage: `--tfs 2.0`
|
||||
Example usage: `--tfs 0.95`
|
||||
|
||||
### Locally Typical Sampling
|
||||
|
||||
|
|
|
@ -121,8 +121,23 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||
printf("\n");
|
||||
}
|
||||
|
||||
void perplexity_lines(llama_context * ctx, const gpt_params & params) {
|
||||
// Calculates perplexity over each line of the prompt
|
||||
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
// Calculates hellaswag score (acc_norm) from prompt
|
||||
//
|
||||
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
|
||||
// All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
|
||||
//
|
||||
// All 10042 tasks should be extracted to keep the results standardized like other implementations.
|
||||
//
|
||||
// Datafile layout:
|
||||
// ['??'] denotes json fields
|
||||
// 6 lines per task:
|
||||
// ['activity_label'] + ": " +['ctx'] - The first part of the query, the context
|
||||
// ['label'] - The index the best common sense ending aka gold ending
|
||||
// ['endings'][0] - Endings added to the first part of the query
|
||||
// ['endings'][1]
|
||||
// ['endings'][2]
|
||||
// ['endings'][3]
|
||||
|
||||
std::vector<std::string> prompt_lines;
|
||||
std::istringstream strstream(params.prompt);
|
||||
|
@ -132,63 +147,149 @@ void perplexity_lines(llama_context * ctx, const gpt_params & params) {
|
|||
prompt_lines.push_back(line);
|
||||
}
|
||||
|
||||
if( prompt_lines.size() % 6 != 0) {
|
||||
fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t hs_task_count = prompt_lines.size()/6;
|
||||
fprintf(stderr, "%s : loaded %lu tasks from prompt.\n", __func__, hs_task_count);
|
||||
|
||||
// This is needed as usual for LLaMA models
|
||||
bool prepend_bos = true;
|
||||
|
||||
// Number of tasks to use when computing the score
|
||||
if ( params.hellaswag_tasks < hs_task_count ) {
|
||||
hs_task_count = params.hellaswag_tasks;
|
||||
}
|
||||
|
||||
// The tasks should be randomized so the score stabilizes quickly.
|
||||
bool randomize_tasks = true;
|
||||
|
||||
// The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
|
||||
std::mt19937 rng(1);
|
||||
|
||||
// Dataholder for hellaswag tasks
|
||||
struct hs_data_t {
|
||||
std::string context;
|
||||
size_t gold_ending_idx;
|
||||
std::string ending[4];
|
||||
size_t ending_logprob_count[4];
|
||||
double ending_logprob[4];
|
||||
};
|
||||
|
||||
fprintf(stderr, "%s : selecting %lu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
|
||||
|
||||
// Select and read data from prompt lines
|
||||
hs_data_t *hs_data = new hs_data_t[hs_task_count];
|
||||
for (size_t i=0; i < hs_task_count; i++) {
|
||||
size_t idx = i;
|
||||
|
||||
// Select a random example of those left in the prompt
|
||||
if (randomize_tasks) {
|
||||
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
|
||||
idx = dist(rng);
|
||||
}
|
||||
|
||||
hs_data[i].context = prompt_lines[idx*6];
|
||||
hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
|
||||
for (size_t j=0; j < 4; j++) {
|
||||
hs_data[i].ending[j] = " " + prompt_lines[idx*6+2+j];
|
||||
}
|
||||
|
||||
// Delete the selected random example from the prompt
|
||||
if (randomize_tasks) {
|
||||
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
|
||||
printf("\ntask\tacc_norm\n");
|
||||
|
||||
double acc = 0.0f;
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
int counttotal = 0;
|
||||
size_t n_lines = prompt_lines.size();
|
||||
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
|
||||
|
||||
double nll = 0.0;
|
||||
// Tokenize the context to count tokens
|
||||
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
|
||||
size_t context_size = context_embd.size();
|
||||
|
||||
fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
|
||||
for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
|
||||
|
||||
printf("\nLine\tPPL line\tPPL cumulative\n");
|
||||
// Tokenize the query
|
||||
std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
|
||||
size_t query_size = query_embd.size();
|
||||
|
||||
for (size_t i = 0; i < n_lines; ++i) {
|
||||
// Stop if query wont fit the ctx window
|
||||
if (query_size > (size_t)params.n_ctx) {
|
||||
fprintf(stderr, "%s : number of tokens in query %lu > n_ctxl\n", __func__, query_size);
|
||||
return;
|
||||
}
|
||||
|
||||
// Tokenize and insert BOS at start
|
||||
std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
|
||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
||||
if (query_size < 32) {
|
||||
query_embd.resize(32);
|
||||
}
|
||||
|
||||
size_t batch_size = batch_embd.size();
|
||||
// Evaluate the query
|
||||
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop if line is too long
|
||||
if( batch_size > (size_t)params.n_ctx ) {
|
||||
fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i);
|
||||
return;
|
||||
const auto query_logits = llama_get_logits(ctx);
|
||||
std::vector<float> logits;
|
||||
logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
|
||||
|
||||
hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
|
||||
hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
|
||||
|
||||
// Calculate the logprobs over the ending
|
||||
for (size_t j = context_size-1; j < query_size - 1; j++) {
|
||||
// Calculate probability of next token, given the previous ones.
|
||||
const std::vector<float> tok_logits(
|
||||
logits.begin() + (j + 0) * n_vocab,
|
||||
logits.begin() + (j + 1) * n_vocab);
|
||||
|
||||
const float prob = softmax(tok_logits)[query_embd[ j + 1]];
|
||||
|
||||
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
|
||||
hs_data[task_idx].ending_logprob_count[ending_idx]++;
|
||||
}
|
||||
|
||||
// Calculate the mean token logprob for acc_norm
|
||||
hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
|
||||
|
||||
|
||||
// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
|
||||
// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
|
||||
}
|
||||
|
||||
if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
// Find the ending with maximum logprob
|
||||
size_t ending_logprob_max_idx = -1;
|
||||
double ending_logprob_max_val = -INFINITY;
|
||||
for (size_t j=0; j < 4; j++) {
|
||||
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
|
||||
ending_logprob_max_idx = j;
|
||||
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
|
||||
}
|
||||
}
|
||||
|
||||
const auto batch_logits = llama_get_logits(ctx);
|
||||
std::vector<float> logits;
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
||||
// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
|
||||
|
||||
double nllline = 0.0;
|
||||
int countline = 0;
|
||||
|
||||
// Perplexity over second half of the line
|
||||
for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
|
||||
// Calculate probability of next token, given the previous ones.
|
||||
const std::vector<float> tok_logits(
|
||||
logits.begin() + (j + 0) * n_vocab,
|
||||
logits.begin() + (j + 1) * n_vocab);
|
||||
|
||||
const float prob = softmax(tok_logits)[batch_embd[ j + 1]];
|
||||
|
||||
nllline += -std::log(prob);
|
||||
++countline;
|
||||
// If the gold ending got the maximum logprobe add one accuracy point
|
||||
if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
|
||||
acc += 1.0;
|
||||
}
|
||||
|
||||
nll += nllline;
|
||||
counttotal += countline;
|
||||
|
||||
// perplexity is e^(average negative log-likelihood)
|
||||
printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) );
|
||||
// Print the accumulated accuracy mean x 100
|
||||
printf("%li\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
delete [] hs_data;
|
||||
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
@ -240,8 +341,8 @@ int main(int argc, char ** argv) {
|
|||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
|
||||
if (params.perplexity_lines) {
|
||||
perplexity_lines(ctx, params);
|
||||
if (params.hellaswag) {
|
||||
hellaswag_score(ctx, params);
|
||||
} else {
|
||||
perplexity(ctx, params);
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ int main(int argc, char ** argv) {
|
|||
auto lparams = llama_context_default_params();
|
||||
|
||||
lparams.n_ctx = params.n_ctx;
|
||||
lparams.n_gqa = params.n_gqa;
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
|
|
26
examples/server-llama2-13B.sh
Normal file
26
examples/server-llama2-13B.sh
Normal file
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.." || exit
|
||||
|
||||
# Specify the model you want to use here:
|
||||
MODEL="${MODEL:-./models/llama-2-13b-chat.ggmlv3.q5_K_M.bin}"
|
||||
PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-./prompts/chat-system.txt}
|
||||
|
||||
# Adjust to the number of CPU cores you want to use.
|
||||
N_THREAD="${N_THREAD:-12}"
|
||||
|
||||
# Note: you can also override the generation options by specifying them on the command line:
|
||||
GEN_OPTIONS="${GEN_OPTIONS:---ctx_size 4096 --batch-size 1024}"
|
||||
|
||||
|
||||
# shellcheck disable=SC2086 # Intended splitting of GEN_OPTIONS
|
||||
./server $GEN_OPTIONS \
|
||||
--model "$MODEL" \
|
||||
--threads "$N_THREAD" \
|
||||
--rope-freq-scale 1.0 \
|
||||
"$@"
|
||||
|
||||
# I used this to test the model with mps, but omitted it from the general purpose. If you want to use it, just specify it on the command line.
|
||||
# -ngl 1 \
|
109
examples/server/chat-llama2.sh
Normal file
109
examples/server/chat-llama2.sh
Normal file
|
@ -0,0 +1,109 @@
|
|||
#!/bin/bash
|
||||
|
||||
API_URL="${API_URL:-http://127.0.0.1:8080}"
|
||||
|
||||
CHAT=(
|
||||
"Hello, Assistant."
|
||||
"Hello. How may I help you today?"
|
||||
)
|
||||
|
||||
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
|
||||
trim() {
|
||||
shopt -s extglob
|
||||
set -- "${1##+([[:space:]])}"
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
trim_trailing() {
|
||||
shopt -s extglob
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
format_prompt() {
|
||||
if [[ "${#CHAT[@]}" -eq 0 ]]; then
|
||||
echo -n "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>"
|
||||
else
|
||||
LAST_INDEX=$(( ${#CHAT[@]} - 1 ))
|
||||
echo -n "${CHAT[$LAST_INDEX]}\n[INST] $1 [/INST]"
|
||||
fi
|
||||
}
|
||||
|
||||
tokenize() {
|
||||
curl \
|
||||
--silent \
|
||||
--request POST \
|
||||
--url "${API_URL}/tokenize" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
|
||||
| jq '.tokens[]'
|
||||
}
|
||||
|
||||
N_KEEP=$(tokenize "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>" | wc -l)
|
||||
|
||||
chat_completion() {
|
||||
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
|
||||
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
|
||||
prompt: .,
|
||||
temperature: 0.2,
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
n_keep: $n_keep,
|
||||
n_predict: 1024,
|
||||
stop: ["[INST]"],
|
||||
stream: true
|
||||
}')"
|
||||
|
||||
# Create a temporary file to hold the Python output
|
||||
TEMPFILE=$(mktemp)
|
||||
|
||||
exec 3< <(curl \
|
||||
--silent \
|
||||
--no-buffer \
|
||||
--request POST \
|
||||
--url "${API_URL}/completion" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "${DATA}")
|
||||
|
||||
python -c "
|
||||
import json
|
||||
import sys
|
||||
|
||||
answer = ''
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
if line.startswith('data: '):
|
||||
json_content = line[6:].strip()
|
||||
content = json.loads(json_content)['content']
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
answer += content
|
||||
|
||||
answer = answer.rstrip('\n')
|
||||
|
||||
# Write the answer to the temporary file
|
||||
with open('$TEMPFILE', 'w') as f:
|
||||
f.write(answer)
|
||||
" <&3
|
||||
|
||||
exec 3<&-
|
||||
|
||||
# Read the answer from the temporary file
|
||||
ANSWER=$(cat $TEMPFILE)
|
||||
|
||||
# Clean up the temporary file
|
||||
rm $TEMPFILE
|
||||
|
||||
printf "\n"
|
||||
|
||||
CHAT+=("$1" "$(trim "$ANSWER")")
|
||||
}
|
||||
|
||||
while true; do
|
||||
echo -en "\033[0;32m" # Green color
|
||||
read -r -e -p "> " QUESTION
|
||||
echo -en "\033[0m" # Reset color
|
||||
chat_completion "${QUESTION}"
|
||||
done
|
182
ggml.c
182
ggml.c
|
@ -4072,8 +4072,8 @@ bool ggml_is_numa(void) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ggml_print_object(const struct ggml_object * obj) {
|
||||
GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
|
||||
obj->offs, obj->size, (const void *) obj->next);
|
||||
GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
|
||||
obj->type, obj->offs, obj->size, (const void *) obj->next);
|
||||
}
|
||||
|
||||
void ggml_print_objects(const struct ggml_context * ctx) {
|
||||
|
@ -4213,7 +4213,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|||
}
|
||||
|
||||
size_t ggml_tensor_overhead(void) {
|
||||
return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
|
||||
return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
|
||||
}
|
||||
|
||||
bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
||||
|
@ -4384,7 +4384,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
return NULL;
|
||||
}
|
||||
|
||||
const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
|
||||
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
|
||||
|
||||
*ctx = (struct ggml_context) {
|
||||
/*.mem_size =*/ mem_size,
|
||||
|
@ -4473,12 +4473,14 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
|
|||
struct ggml_object * obj = ctx->objects_begin;
|
||||
|
||||
while (obj != NULL) {
|
||||
struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
|
||||
if (obj->type == GGML_OBJECT_TENSOR) {
|
||||
struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
|
||||
|
||||
const size_t size = ggml_nbytes(tensor);
|
||||
const size_t size = ggml_nbytes(tensor);
|
||||
|
||||
if (max_size < size) {
|
||||
max_size = size;
|
||||
if (max_size < size) {
|
||||
max_size = size;
|
||||
}
|
||||
}
|
||||
|
||||
obj = obj->next;
|
||||
|
@ -4510,12 +4512,7 @@ static void ggml_scratch_load(struct ggml_context * ctx) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
struct ggml_context * ctx,
|
||||
enum ggml_type type,
|
||||
int n_dims,
|
||||
const int64_t* ne,
|
||||
void* data) {
|
||||
static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
|
||||
// always insert objects at the end of the context's memory pool
|
||||
struct ggml_object * obj_cur = ctx->objects_end;
|
||||
|
||||
|
@ -4523,63 +4520,28 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
|||
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
|
||||
const size_t cur_end = cur_offs + cur_size;
|
||||
|
||||
size_t size_needed = 0;
|
||||
|
||||
if (data == NULL && !ctx->no_alloc) {
|
||||
size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
|
||||
for (int i = 1; i < n_dims; i++) {
|
||||
size_needed *= ne[i];
|
||||
}
|
||||
// align to GGML_MEM_ALIGN
|
||||
size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
|
||||
}
|
||||
// align to GGML_MEM_ALIGN
|
||||
size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
|
||||
|
||||
char * const mem_buffer = ctx->mem_buffer;
|
||||
struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
|
||||
|
||||
if (ctx->scratch.data == NULL || data != NULL) {
|
||||
size_needed += GGML_TENSOR_SIZE;
|
||||
|
||||
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
|
||||
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
|
||||
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
|
||||
assert(false);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
*obj_new = (struct ggml_object) {
|
||||
.offs = cur_end + GGML_OBJECT_SIZE,
|
||||
.size = size_needed,
|
||||
.next = NULL,
|
||||
};
|
||||
} else {
|
||||
if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
|
||||
GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
|
||||
__func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
|
||||
assert(false);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) {
|
||||
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
|
||||
__func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size);
|
||||
assert(false);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data = (char * const) ctx->scratch.data + ctx->scratch.offs;
|
||||
|
||||
*obj_new = (struct ggml_object) {
|
||||
.offs = cur_end + GGML_OBJECT_SIZE,
|
||||
.size = GGML_TENSOR_SIZE,
|
||||
.next = NULL,
|
||||
};
|
||||
|
||||
//printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
|
||||
|
||||
ctx->scratch.offs += size_needed;
|
||||
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
|
||||
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
|
||||
__func__, cur_end + size_needed, ctx->mem_size);
|
||||
assert(false);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
*obj_new = (struct ggml_object) {
|
||||
.offs = cur_end + GGML_OBJECT_SIZE,
|
||||
.size = size_needed,
|
||||
.next = NULL,
|
||||
.type = type,
|
||||
};
|
||||
|
||||
ggml_assert_aligned(mem_buffer + obj_new->offs);
|
||||
|
||||
if (obj_cur != NULL) {
|
||||
obj_cur->next = obj_new;
|
||||
} else {
|
||||
|
@ -4591,9 +4553,46 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
|||
|
||||
//printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
|
||||
|
||||
struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
|
||||
return obj_new;
|
||||
}
|
||||
|
||||
ggml_assert_aligned(result);
|
||||
static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
struct ggml_context * ctx,
|
||||
enum ggml_type type,
|
||||
int n_dims,
|
||||
const int64_t* ne,
|
||||
void* data) {
|
||||
|
||||
size_t data_size = 0;
|
||||
|
||||
if (data == NULL && !ctx->no_alloc) {
|
||||
data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
|
||||
for (int i = 1; i < n_dims; i++) {
|
||||
data_size *= ne[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx->scratch.data != NULL && data == NULL) {
|
||||
// allocate tensor data in the scratch buffer
|
||||
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
|
||||
GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
|
||||
__func__, ctx->scratch.offs + data_size, ctx->scratch.size);
|
||||
assert(false);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data = (char * const) ctx->scratch.data + ctx->scratch.offs;
|
||||
|
||||
ctx->scratch.offs += data_size;
|
||||
|
||||
data_size = 0;
|
||||
}
|
||||
|
||||
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + data_size);
|
||||
|
||||
// TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
|
||||
|
||||
struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
|
||||
|
||||
*result = (struct ggml_tensor) {
|
||||
/*.type =*/ type,
|
||||
|
@ -4984,11 +4983,6 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
|
|||
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
|
||||
}
|
||||
|
||||
static void ggml_set_unary_op(struct ggml_tensor * tensor, enum ggml_unary_op op) {
|
||||
GGML_ASSERT(tensor->op = GGML_OP_UNARY);
|
||||
ggml_set_op_params_i32(tensor, 0, (int32_t) op);
|
||||
}
|
||||
|
||||
const char * ggml_get_name(const struct ggml_tensor * tensor) {
|
||||
return tensor->name;
|
||||
}
|
||||
|
@ -5027,9 +5021,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
|
|||
char * const mem_buffer = ctx->mem_buffer;
|
||||
|
||||
while (obj != NULL) {
|
||||
struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
|
||||
if (strcmp(cur->name, name) == 0) {
|
||||
return cur;
|
||||
if (obj->type == GGML_OBJECT_TENSOR) {
|
||||
struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
|
||||
if (strcmp(cur->name, name) == 0) {
|
||||
return cur;
|
||||
}
|
||||
}
|
||||
|
||||
obj = obj->next;
|
||||
|
@ -7226,7 +7222,7 @@ static struct ggml_tensor * ggml_unary_impl(
|
|||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_unary_op(result, op);
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||
|
||||
result->op = GGML_OP_UNARY;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -15825,6 +15821,35 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
|
|||
return result;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
|
||||
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE);
|
||||
struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
|
||||
|
||||
*cgraph = (struct ggml_cgraph) {
|
||||
/*.n_nodes =*/ 0,
|
||||
/*.n_leafs =*/ 0,
|
||||
/*.nodes =*/ { NULL },
|
||||
/*.grads =*/ { NULL },
|
||||
/*.leafs =*/ { NULL },
|
||||
/*.hash_table =*/ { NULL },
|
||||
/*.perf_runs =*/ 0,
|
||||
/*.perf_cycles =*/ 0,
|
||||
/*.perf_time_us =*/ 0,
|
||||
};
|
||||
|
||||
return cgraph;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
||||
struct ggml_cgraph * cgraph = ggml_new_graph(ctx);
|
||||
ggml_build_forward_impl(cgraph, tensor, false);
|
||||
return cgraph;
|
||||
}
|
||||
|
||||
size_t ggml_graph_overhead(void) {
|
||||
return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN);
|
||||
}
|
||||
|
||||
//
|
||||
// thread data
|
||||
//
|
||||
|
@ -16544,10 +16569,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
|||
void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
|
||||
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
|
||||
|
||||
struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
|
||||
GGML_ASSERT(buf);
|
||||
struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
|
||||
|
||||
cplan.work_data = buf->data;
|
||||
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
|
||||
|
||||
ggml_graph_compute(cgraph, &cplan);
|
||||
}
|
||||
|
|
22
ggml.h
22
ggml.h
|
@ -208,6 +208,8 @@
|
|||
|
||||
#define GGML_UNUSED(x) (void)(x)
|
||||
|
||||
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
|
||||
|
||||
#define GGML_ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
|
@ -395,6 +397,12 @@ extern "C" {
|
|||
GGML_UNARY_OP_SILU,
|
||||
};
|
||||
|
||||
enum ggml_object_type {
|
||||
GGML_OBJECT_TENSOR,
|
||||
GGML_OBJECT_GRAPH,
|
||||
GGML_OBJECT_WORK_BUFFER
|
||||
};
|
||||
|
||||
// ggml object
|
||||
struct ggml_object {
|
||||
size_t offs;
|
||||
|
@ -402,7 +410,9 @@ extern "C" {
|
|||
|
||||
struct ggml_object * next;
|
||||
|
||||
char padding[8];
|
||||
enum ggml_object_type type;
|
||||
|
||||
char padding[4];
|
||||
};
|
||||
|
||||
static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
|
||||
|
@ -423,7 +433,7 @@ extern "C" {
|
|||
enum ggml_op op;
|
||||
|
||||
// op params - allocated as int32_t for alignment
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
|
||||
bool is_param;
|
||||
|
||||
|
@ -484,6 +494,8 @@ extern "C" {
|
|||
int64_t perf_time_us;
|
||||
};
|
||||
|
||||
static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
|
||||
|
||||
// scratch buffer
|
||||
struct ggml_scratch {
|
||||
size_t offs;
|
||||
|
@ -1390,11 +1402,17 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * tensor);
|
||||
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
|
||||
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
|
||||
|
||||
// graph allocation in a context
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
|
||||
GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
|
||||
GGML_API size_t ggml_graph_overhead(void);
|
||||
|
||||
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
||||
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
||||
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
|
||||
|
|
62
k_quants.c
62
k_quants.c
|
@ -39,6 +39,8 @@
|
|||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
|
||||
//
|
||||
// 2-6 bit quantization in super-blocks
|
||||
//
|
||||
|
@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
|
||||
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
|
||||
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
|
||||
|
@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
|
||||
|
||||
// sumf += -dmin * summs in 32bits*8
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
|
||||
|
||||
const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
|
||||
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
|
||||
|
@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
}
|
||||
|
||||
// sumf += dall * isum - dmin * summs in 32bits
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
|
||||
}
|
||||
|
||||
|
@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
summs += dmin * smin;
|
||||
|
||||
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
||||
const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
|
||||
const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
|
||||
const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
|
||||
const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
@ -1709,10 +1711,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
|
||||
const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
|
||||
const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
|
||||
const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
|
||||
const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
|
||||
const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
|
||||
const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
|
||||
const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
|
||||
|
@ -1917,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
|
||||
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
|
||||
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
||||
|
||||
// high bit
|
||||
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
|
||||
|
@ -2128,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
}
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
|
@ -2303,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
aux16[0] = a & 0x0f0f;
|
||||
aux16[1] = (a >> 4) & 0x0f0f;
|
||||
|
||||
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
||||
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
||||
const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
||||
const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
||||
|
||||
memcpy(&aux64, x[i].hmask, 8);
|
||||
|
||||
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
||||
__m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
|
||||
__m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
|
||||
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
||||
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
||||
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
||||
|
@ -2318,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
||||
|
||||
// prepare low and high bits
|
||||
const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
|
||||
const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
|
||||
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
||||
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
||||
|
||||
|
@ -2429,7 +2431,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
p16_0 = _mm_add_epi32(p16_0, p16_2);
|
||||
p16_1 = _mm_add_epi32(p16_1, p16_3);
|
||||
__m256i p16 = _mm256_set_m128i(p16_1, p16_0);
|
||||
__m256i p16 = MM256_SET_M128I(p16_1, p16_0);
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
|
||||
|
@ -2620,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
||||
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
|
||||
|
@ -2727,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
}
|
||||
|
||||
__m256 vd = _mm256_set1_ps(d);
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
|
@ -2968,11 +2970,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
|
||||
const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
|
||||
|
||||
const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
|
||||
const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
|
||||
|
||||
}
|
||||
|
||||
|
@ -3160,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
summs += dmin * _mm_extract_epi32(hsum, 0);
|
||||
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
|
||||
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
|
||||
__m256i hmask = mone;
|
||||
|
@ -3299,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
}
|
||||
|
||||
__m256 vd = _mm256_set1_ps(d);
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
|
@ -3462,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
||||
|
||||
const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
||||
const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
||||
const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
||||
const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
||||
|
||||
int64_t aux64;
|
||||
memcpy(&aux64, x[i].qh, 8);
|
||||
const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
|
||||
const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
|
||||
const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
|
||||
|
||||
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
|
||||
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
|
||||
|
@ -3543,7 +3545,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
|
||||
const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
|
||||
|
||||
}
|
||||
|
||||
|
@ -3925,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
}
|
||||
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
||||
}
|
||||
|
||||
|
@ -4083,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
||||
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
||||
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
||||
|
||||
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
|
||||
|
@ -4177,7 +4179,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
||||
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
|
43
llama.cpp
43
llama.cpp
|
@ -1431,7 +1431,7 @@ static bool llama_eval_internal(
|
|||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
|
@ -1548,8 +1548,8 @@ static bool llama_eval_internal(
|
|||
ggml_set_name(v, "v");
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
|
@ -1719,21 +1719,22 @@ static bool llama_eval_internal(
|
|||
//cur = ggml_soft_max_inplace(ctx0, cur);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, cur);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
|
||||
|
||||
#if GGML_USE_MPI
|
||||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
|
||||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (lctx.ctx_metal && N == 1) {
|
||||
if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
|
||||
ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
|
||||
}
|
||||
// TODO: disabled until #2413 is resolved
|
||||
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
|
||||
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
|
||||
//}
|
||||
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(lctx.ctx_metal, &gf);
|
||||
ggml_metal_graph_compute(lctx.ctx_metal, gf);
|
||||
ggml_metal_get_tensor (lctx.ctx_metal, cur);
|
||||
} else {
|
||||
// IMPORTANT:
|
||||
|
@ -1752,34 +1753,34 @@ static bool llama_eval_internal(
|
|||
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
|
||||
}
|
||||
|
||||
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
|
||||
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
|
||||
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
|
||||
#if GGML_USE_MPI
|
||||
ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
|
||||
ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
// update kv token count
|
||||
lctx.kv_self.n = n_past + N;
|
||||
|
||||
struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
|
||||
if (cgraph_fname) {
|
||||
ggml_graph_export(&gf, cgraph_fname);
|
||||
ggml_graph_export(gf, cgraph_fname);
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
// requires GGML_PERF to be defined
|
||||
ggml_graph_print(&gf);
|
||||
ggml_graph_print(gf);
|
||||
#endif
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_dump_dot(&gf, NULL, "llama.dot");
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
// extract logits
|
||||
|
@ -1930,7 +1931,9 @@ struct llama_tokenizer {
|
|||
if (token == vocab_.token_to_id.end()) {
|
||||
// output any symbols that did not form tokens as bytes.
|
||||
for (int j = 0; j < (int) symbol.n; ++j) {
|
||||
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||
// NOTE: old version, before #2420 - not sure what are the implications of this
|
||||
//llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||
llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j]));
|
||||
output.push_back(token_id);
|
||||
}
|
||||
} else {
|
||||
|
@ -3186,7 +3189,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->embedding.resize(hparams.n_embd);
|
||||
}
|
||||
|
||||
ctx->buf_compute.resize(blasbatchmul*MEM_REQ_EVAL().at(ctx->model.type));
|
||||
ctx->buf_compute.resize(blasbatchmul*MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
|
||||
|
||||
ctx->buf_scratch[0].resize(blasbatchmul*MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
|
||||
ctx->buf_scratch[1].resize(blasbatchmul*MEM_REQ_SCRATCH1().at(ctx->model.type));
|
||||
|
@ -3671,7 +3674,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
|
|||
const auto & kv_self = ctx->kv_self;
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_embd = hparams.n_embd_gqa();
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const size_t kv_size = kv_self.buf.size;
|
||||
|
@ -3774,7 +3777,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
|||
const auto & kv_self = ctx->kv_self;
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_embd = hparams.n_embd_gqa();
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
size_t kv_size;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue