llama : add llama_sampling and combine it with llama_grammar
ggml-ci
This commit is contained in:
parent
1262e7ed13
commit
cc53500f65
45 changed files with 1759 additions and 1701 deletions
6
Makefile
6
Makefile
|
@ -923,7 +923,6 @@ OBJ_COMMON = \
|
|||
common/ngram-cache.o \
|
||||
common/sampling.o \
|
||||
common/train.o \
|
||||
common/grammar-parser.o \
|
||||
common/build-info.o \
|
||||
common/json-schema-to-grammar.o
|
||||
|
||||
|
@ -1163,11 +1162,6 @@ common/console.o: \
|
|||
common/console.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
common/grammar-parser.o: \
|
||||
common/grammar-parser.cpp \
|
||||
common/grammar-parser.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
common/json-schema-to-grammar.o: \
|
||||
common/json-schema-to-grammar.cpp \
|
||||
common/json-schema-to-grammar.h
|
||||
|
|
|
@ -58,8 +58,6 @@ add_library(${TARGET} STATIC
|
|||
sampling.cpp
|
||||
console.h
|
||||
console.cpp
|
||||
grammar-parser.h
|
||||
grammar-parser.cpp
|
||||
json.hpp
|
||||
json-schema-to-grammar.cpp
|
||||
train.h
|
||||
|
|
|
@ -2161,7 +2161,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
|||
}
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
llama_reset_timings(lctx);
|
||||
llama_reset_timings(lctx, nullptr);
|
||||
}
|
||||
|
||||
iparams.model = model;
|
||||
|
|
|
@ -1,539 +0,0 @@
|
|||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from llama.cpp
|
||||
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
static void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<llama_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
static bool is_digit_char(char c) {
|
||||
return '0' <= c && c <= '9';
|
||||
}
|
||||
|
||||
static bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
static const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static const char * parse_int(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_digit_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting integer at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
static const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<llama_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
|
||||
auto handle_repetitions = [&](int min_times, int max_times) {
|
||||
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// the following rewrite rules:
|
||||
// S{m,n} --> S S S (m times) S'(n-m)
|
||||
// S'(x) ::= S S'(x-1) |
|
||||
// (... n-m definitions of these S' rules ...)
|
||||
// S'(1) ::= S |
|
||||
// S{m,} --> S S S (m times) S'
|
||||
// S' ::= S S' |
|
||||
// S* --> S{0,}
|
||||
// --> S' ::= S S' |
|
||||
// S+ --> S{1,}
|
||||
// --> S S'
|
||||
// S' ::= S S' |
|
||||
// S? --> S{0,1}
|
||||
// --> S'
|
||||
// S' ::= S |
|
||||
|
||||
std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (min_times == 0) {
|
||||
out_elements.resize(last_sym_start);
|
||||
} else {
|
||||
// Repeat the previous elements (min_times - 1) times
|
||||
for (int i = 1; i < min_times; i++) {
|
||||
out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t last_rec_rule_id = 0;
|
||||
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
||||
|
||||
std::vector<llama_grammar_element> rec_rule(previous_elements);
|
||||
for (int i = 0; i < n_opt; i++) {
|
||||
rec_rule.resize(previous_elements.size());
|
||||
uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
|
||||
if (i > 0 || max_times < 0) {
|
||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
||||
}
|
||||
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
add_rule(state, rec_rule_id, rec_rule);
|
||||
last_rec_rule_id = rec_rule_id;
|
||||
}
|
||||
if (n_opt > 0) {
|
||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||
}
|
||||
};
|
||||
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
if (!*pos) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
if (!*pos) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum llama_gretype type = last_sym_start < out_elements.size()
|
||||
? LLAMA_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
if (!pos[1]) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '.') { // any char
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(0, -1);
|
||||
} else if (*pos == '+') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(1, -1);
|
||||
} else if (*pos == '?') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(0, 1);
|
||||
} else if (*pos == '{') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
|
||||
if (!is_digit_char(*pos)) {
|
||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||
}
|
||||
const char * int_end = parse_int(pos);
|
||||
int min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
|
||||
int max_times = -1;
|
||||
|
||||
if (*pos == '}') {
|
||||
max_times = min_times;
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == ',') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
|
||||
if (*pos != '}') {
|
||||
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||
}
|
||||
handle_repetitions(min_times, max_times);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<llama_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
static const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
// Validate the state to ensure that all rules are defined
|
||||
for (const auto & rule : state.rules) {
|
||||
if (rule.empty()) {
|
||||
throw std::runtime_error("Undefined rule");
|
||||
}
|
||||
for (const auto & elem : rule) {
|
||||
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
|
||||
// Ensure that the rule at that location exists
|
||||
if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
|
||||
// Get the name of the rule that is missing
|
||||
for (const auto & kv : state.symbol_ids) {
|
||||
if (kv.second == elem.value) {
|
||||
throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_char_element(llama_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_CHAR: return true;
|
||||
case LLAMA_GRETYPE_CHAR_NOT: return true;
|
||||
case LLAMA_GRETYPE_CHAR_ALT: return true;
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
case LLAMA_GRETYPE_CHAR_ANY: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_END:
|
||||
case LLAMA_GRETYPE_ALT:
|
||||
case LLAMA_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||
case LLAMA_GRETYPE_CHAR_ALT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
static void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<llama_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
llama_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case LLAMA_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
fprintf(file, ".");
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case LLAMA_GRETYPE_CHAR_ALT:
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (const auto & kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> parse_state::c_rules() {
|
||||
std::vector<const llama_grammar_element *> ret;
|
||||
ret.reserve(rules.size());
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by llama.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "llama.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<llama_grammar_element>> rules;
|
||||
|
||||
std::vector<const llama_grammar_element *> c_rules();
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
|
@ -1,99 +1,53 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
#include "sampling.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
|
||||
auto result = llama_sampling_init(params, llama_sampling_init(model, params.grammar.c_str(), "root"));
|
||||
|
||||
result->owned = true;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
|
||||
result->params = params;
|
||||
result->grammar = nullptr;
|
||||
|
||||
// if there is a grammar, parse it
|
||||
if (!params.grammar.empty()) {
|
||||
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (result->parsed_grammar.rules.empty()) {
|
||||
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Ensure that there is a "root" node.
|
||||
if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
|
||||
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
||||
|
||||
struct llama_grammar * grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||
if (grammar == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
result->grammar = grammar;
|
||||
}
|
||||
result->owned = false;
|
||||
result->smpl = smpl;
|
||||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
result->n_valid = 0;
|
||||
|
||||
llama_sampling_set_rng_seed(result, params.seed);
|
||||
llama_sampling_set_rng_seed(result->smpl, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampling_free(struct llama_sampling_context * ctx) {
|
||||
if (ctx->grammar != NULL) {
|
||||
llama_grammar_free(ctx->grammar);
|
||||
if (ctx->owned) {
|
||||
llama_sampling_free(ctx->smpl);
|
||||
}
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||
if (ctx->grammar != NULL) {
|
||||
llama_grammar_free(ctx->grammar);
|
||||
ctx->grammar = NULL;
|
||||
}
|
||||
|
||||
if (!ctx->parsed_grammar.rules.empty()) {
|
||||
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
||||
|
||||
struct llama_grammar * grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||
if (grammar == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
ctx->grammar = grammar;
|
||||
}
|
||||
llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");
|
||||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
ctx->n_valid = 0;
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = std::random_device{}();
|
||||
}
|
||||
ctx->rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
|
||||
if (dst->grammar) {
|
||||
llama_grammar_free(dst->grammar);
|
||||
dst->grammar = nullptr;
|
||||
}
|
||||
|
||||
if (src->grammar) {
|
||||
dst->grammar = llama_grammar_copy(src->grammar);
|
||||
if (dst->smpl) {
|
||||
llama_sampling_free(dst->smpl);
|
||||
}
|
||||
|
||||
dst->smpl = llama_sampling_cp(src->smpl);
|
||||
dst->prev = src->prev;
|
||||
}
|
||||
|
||||
|
@ -228,10 +182,13 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
|
|||
|
||||
// no reasons to expose this function in header
|
||||
static void sampler_queue(
|
||||
struct llama_context * ctx_main,
|
||||
const llama_sampling_params & params,
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
llama_token_data_array & cur_p,
|
||||
size_t min_keep) {
|
||||
llama_sampling * smpl = ctx_sampling->smpl;
|
||||
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const float temp = params.temp;
|
||||
const float dynatemp_range = params.dynatemp_range;
|
||||
const float dynatemp_exponent = params.dynatemp_exponent;
|
||||
|
@ -244,18 +201,18 @@ static void sampler_queue(
|
|||
|
||||
for (auto sampler_type : samplers_sequence) {
|
||||
switch (sampler_type) {
|
||||
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
||||
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
||||
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
|
||||
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
|
||||
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
|
||||
case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break;
|
||||
case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break;
|
||||
case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break;
|
||||
case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break;
|
||||
case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break;
|
||||
case llama_sampler_type::TEMPERATURE:
|
||||
if (dynatemp_range > 0) {
|
||||
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
|
||||
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
|
||||
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
|
||||
llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
|
||||
} else {
|
||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
llama_sampling_temp(smpl, &cur_p, temp);
|
||||
}
|
||||
break;
|
||||
default : break;
|
||||
|
@ -269,6 +226,8 @@ static llama_token llama_sampling_sample_impl(
|
|||
struct llama_context * ctx_cfg,
|
||||
const int idx,
|
||||
bool is_resampling) {
|
||||
llama_sampling * smpl = ctx_sampling->smpl;
|
||||
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const float temp = params.temp;
|
||||
|
@ -278,33 +237,33 @@ static llama_token llama_sampling_sample_impl(
|
|||
|
||||
std::vector<float> original_logits;
|
||||
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
||||
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||
if (!is_resampling) {
|
||||
GGML_ASSERT(!original_logits.empty());
|
||||
}
|
||||
llama_token id = 0;
|
||||
|
||||
if (temp < 0.0) {
|
||||
// greedy sampling, with probs
|
||||
llama_sample_softmax(ctx_main, &cur_p);
|
||||
llama_sampling_softmax(smpl, &cur_p);
|
||||
id = cur_p.data[0].id;
|
||||
} else if (temp == 0.0) {
|
||||
// greedy sampling, no probs
|
||||
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
||||
id = llama_sampling_sample_greedy(smpl, &cur_p);
|
||||
} else {
|
||||
if (mirostat == 1) {
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
|
||||
llama_sampling_temp(smpl, &cur_p, temp);
|
||||
id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||
llama_sampling_temp(smpl, &cur_p, temp);
|
||||
id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||
} else {
|
||||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.min_keep);
|
||||
|
||||
sampler_queue(ctx_main, params, cur_p, min_keep);
|
||||
sampler_queue(ctx_sampling, cur_p, min_keep);
|
||||
|
||||
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
|
||||
id = llama_sampling_sample(smpl, &cur_p);
|
||||
|
||||
//{
|
||||
// const int n_top = 10;
|
||||
|
@ -313,15 +272,15 @@ static llama_token llama_sampling_sample_impl(
|
|||
// for (int i = 0; i < n_top; i++) {
|
||||
// const llama_token id = cur_p.data[i].id;
|
||||
// (void)id; // To avoid a warning that id is unused when logging is disabled.
|
||||
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
|
||||
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
|
||||
// }
|
||||
//}
|
||||
|
||||
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||
if (!is_resampling) {
|
||||
// Get a pointer to the logits
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
|
@ -330,7 +289,7 @@ static llama_token llama_sampling_sample_impl(
|
|||
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
||||
|
||||
// Apply grammar constraints to the single token
|
||||
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
|
||||
llama_sampling_grammar(ctx_sampling->smpl, &single_token_data_array);
|
||||
|
||||
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
||||
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||
|
@ -358,6 +317,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
const int idx,
|
||||
bool apply_grammar,
|
||||
std::vector<float> * original_logits) {
|
||||
llama_sampling * smpl = ctx_sampling->smpl;
|
||||
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
@ -375,7 +336,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
// Get a pointer to the logits
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
||||
if (!apply_grammar) {
|
||||
GGML_ASSERT(original_logits != NULL);
|
||||
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
|
||||
*original_logits = {logits, logits + n_vocab};
|
||||
|
@ -388,7 +349,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
|
||||
if (ctx_cfg) {
|
||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||
llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
|
||||
}
|
||||
|
||||
cur.resize(n_vocab);
|
||||
|
@ -405,7 +366,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
if (penalty_tokens_used_size) {
|
||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
llama_sampling_repetition_penalties(smpl, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||
|
||||
|
@ -420,8 +381,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
}
|
||||
|
||||
// apply grammar checks before sampling logic
|
||||
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
||||
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
|
||||
if (apply_grammar) {
|
||||
llama_sampling_grammar(ctx_sampling->smpl, &cur_p);
|
||||
}
|
||||
|
||||
return cur_p;
|
||||
|
@ -443,18 +404,17 @@ llama_token_data_array llama_sampling_prepare(
|
|||
const int idx,
|
||||
bool apply_grammar,
|
||||
std::vector<float> * original_logits) {
|
||||
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
|
||||
return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
|
||||
}
|
||||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar) {
|
||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
||||
ctx_sampling->prev.push_back(id);
|
||||
|
||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||
if (apply_grammar) {
|
||||
llama_sampling_accept(ctx_sampling->smpl, id);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
#include "llama.h"
|
||||
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
@ -73,23 +71,22 @@ struct llama_sampling_context {
|
|||
// mirostat sampler state
|
||||
float mirostat_mu;
|
||||
|
||||
llama_grammar * grammar;
|
||||
bool owned;
|
||||
|
||||
// internal
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
llama_sampling * smpl;
|
||||
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
||||
|
||||
std::mt19937 rng;
|
||||
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
||||
};
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// Create a new sampling context instance.
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);
|
||||
|
||||
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||
|
||||
|
@ -98,9 +95,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
|
|||
// - reset grammar
|
||||
void llama_sampling_reset(llama_sampling_context * ctx);
|
||||
|
||||
// Set the sampler seed
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
|
||||
|
||||
// Copy the sampler context
|
||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
||||
|
||||
|
@ -155,6 +149,5 @@ llama_token_data_array llama_sampling_prepare(
|
|||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar);
|
||||
|
|
|
@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
|
@ -44,6 +44,8 @@ context_params.n_threads = 8
|
|||
context_params.n_threads_batch = 8
|
||||
|
||||
let context = llama_new_context_with_model(model, context_params)
|
||||
let smpl = llama_get_sampling(context)
|
||||
|
||||
guard context != nil else {
|
||||
print("Failed to initialize context")
|
||||
exit(1)
|
||||
|
@ -144,13 +146,13 @@ while n_cur <= n_len {
|
|||
let top_p: Float = 0.9
|
||||
let temp: Float = 0.4
|
||||
|
||||
llama_sample_top_k(context, &candidates_p, top_k, 1)
|
||||
llama_sample_top_p(context, &candidates_p, top_p, 1)
|
||||
llama_sample_temp(context, &candidates_p, temp)
|
||||
llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
|
||||
llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
|
||||
llama_sampling_temp(smpl, &candidates_p, temp)
|
||||
|
||||
let new_token_id = llama_sample_token(context, &candidates_p)
|
||||
let new_token_id = llama_sampling_sample(smpl, &candidates_p)
|
||||
|
||||
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
|
@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()
|
|||
|
||||
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
|
||||
|
||||
llama_print_timings(context)
|
||||
llama_print_timings(context, smpl)
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
let utf8Count = text.utf8.count
|
||||
|
|
|
@ -64,6 +64,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
llama_sampling * smpl = llama_get_sampling(ctx);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
|
@ -180,13 +181,13 @@ int main(int argc, char ** argv) {
|
|||
const float top_p = 0.9f;
|
||||
const float temp = 0.4f;
|
||||
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temp (ctx, &candidates_p, temp);
|
||||
llama_sampling_top_k(smpl, &candidates_p, top_k, 1);
|
||||
llama_sampling_top_p(smpl, &candidates_p, top_p, 1);
|
||||
llama_sampling_temp (smpl, &candidates_p, temp);
|
||||
|
||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||
const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
|
||||
|
||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
|
@ -244,12 +245,13 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, smpl);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_sampling_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
|
|
|
@ -314,7 +314,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
llama_batch_free(batch);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -183,7 +183,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
|
||||
#include "grammar-parser.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "llama-vocab.h" // TMP
|
||||
#include "llama-grammar.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <cstdio>
|
||||
|
@ -85,27 +84,8 @@ int main(int argc, char** argv) {
|
|||
grammar_str = buffer.str();
|
||||
}
|
||||
|
||||
// Parse the GBNF grammar
|
||||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
fprintf(stdout, "%s: failed to parse grammar\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Ensure that there is a "root" node.
|
||||
if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
|
||||
fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
|
||||
// Create the LLAMA grammar
|
||||
auto grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
llama_vocab vocab; // TMP
|
||||
llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
|
||||
if (grammar == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
|
@ -131,7 +111,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// Clean up
|
||||
llama_grammar_free(grammar);
|
||||
llama_grammar_free_impl(grammar);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
|
||||
std::vector<std::vector<float>> result;
|
||||
|
||||
const llama_model * mdl = llama_get_model(ctx);
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
|
||||
|
@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
|
||||
const std::string input_string = instruction + sentences[i];
|
||||
|
||||
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
|
||||
std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
|
||||
|
||||
const int32_t n_toks = inputs.size();
|
||||
|
||||
// GritLM seems to have EOS = ""
|
||||
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
|
||||
// inputs.push_back(llama_token_eos(mdl));
|
||||
// inputs.push_back(llama_token_eos(model));
|
||||
|
||||
// we want to ignore instruction tokens for mean pooling
|
||||
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
|
||||
const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
|
||||
|
||||
#ifdef GRIT_DEBUG
|
||||
// debug tokens - should be matching as referenced in the GritLM sample
|
||||
|
@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
llama_decode(ctx, batch);
|
||||
|
||||
// get embedding dimensions
|
||||
uint64_t n_embd = llama_n_embd(mdl);
|
||||
uint64_t n_embd = llama_n_embd(model);
|
||||
|
||||
// allocate embedding output
|
||||
std::vector<float> emb_unorm(n_embd, 0.0f);
|
||||
|
@ -95,8 +95,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
|
||||
std::string result;
|
||||
|
||||
const llama_model * mdl = llama_get_model(ctx);
|
||||
llama_token eos_token = llama_token_eos(mdl);
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
llama_sampling * smpl = llama_get_sampling(ctx);
|
||||
llama_token eos_token = llama_token_eos(model);
|
||||
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_set_embeddings(ctx, false);
|
||||
|
@ -104,7 +105,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
|
|||
|
||||
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
|
||||
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
|
||||
std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
|
||||
int32_t i_current_token = 0;
|
||||
|
||||
while (true) {
|
||||
|
@ -118,14 +119,14 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
|
|||
llama_decode(ctx, bat);
|
||||
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
||||
|
||||
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
||||
auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
|
||||
auto n_candidates = (int32_t)candidates.size();
|
||||
for (int32_t token = 0; token < n_candidates; token++) {
|
||||
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
|
||||
}
|
||||
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
|
||||
|
||||
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
if (token == eos_token) {
|
||||
break;
|
||||
}
|
||||
|
@ -167,10 +168,10 @@ int main(int argc, char * argv[]) {
|
|||
|
||||
llama_backend_init();
|
||||
|
||||
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
|
||||
// create generation context
|
||||
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
|
||||
llama_context * ctx = llama_new_context_with_model(model, cparams);
|
||||
|
||||
// ### Embedding/Representation ###
|
||||
// samples taken from: https://github.com/ContextualAI/gritlm#basic
|
||||
|
@ -191,7 +192,7 @@ int main(int argc, char * argv[]) {
|
|||
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
|
||||
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
|
||||
|
||||
const int n_embd = llama_n_embd(mdl);
|
||||
const int n_embd = llama_n_embd(model);
|
||||
|
||||
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
||||
|
@ -212,7 +213,7 @@ int main(int argc, char * argv[]) {
|
|||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(mdl);
|
||||
llama_free_model(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -638,7 +638,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
g_collector.save_imatrix();
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#include "console.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
|
@ -34,6 +33,7 @@
|
|||
|
||||
static llama_context ** g_ctx;
|
||||
static llama_model ** g_model;
|
||||
static llama_sampling_context ** g_ctx_sampling;
|
||||
static gpt_params * g_params;
|
||||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
|
@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
|
|||
} else {
|
||||
console::cleanup();
|
||||
printf("\n");
|
||||
llama_print_timings(*g_ctx);
|
||||
llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
||||
_exit(130);
|
||||
}
|
||||
|
@ -171,11 +171,13 @@ int main(int argc, char ** argv) {
|
|||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
|
||||
g_model = &model;
|
||||
g_ctx = &ctx;
|
||||
g_ctx_sampling = &ctx_sampling;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
@ -349,7 +351,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
|
||||
|
||||
while (n_remain != 0 || params.interactive) {
|
||||
// predict
|
||||
|
@ -423,7 +425,7 @@ int main(int argc, char ** argv) {
|
|||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(ctx_sampling, id, true);
|
||||
|
||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||
|
||||
|
@ -444,7 +446,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
||||
llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
|
@ -638,7 +640,7 @@ int main(int argc, char ** argv) {
|
|||
fflush(stdout);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, ctx_sampling->smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
|
@ -1463,7 +1463,7 @@ int main(int argc, char ** argv) {
|
|||
fflush(p_err->fout);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
llama_free(ctx);
|
||||
}
|
||||
|
|
|
@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
jobject intvar_ncur
|
||||
) {
|
||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||
const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
|
||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||
const auto model = llama_get_model(context);
|
||||
|
||||
|
@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
|
||||
const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
|
||||
|
||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
|
|
|
@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
|||
actor LlamaContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var sampling: OpaquePointer
|
||||
private var batch: llama_batch
|
||||
private var tokens_list: [llama_token]
|
||||
var is_done: Bool = false
|
||||
|
@ -42,12 +43,14 @@ actor LlamaContext {
|
|||
self.tokens_list = []
|
||||
self.batch = llama_batch_init(512, 0, 1)
|
||||
self.temporary_invalid_cchars = []
|
||||
self.sampling = llama_get_sampling(context)
|
||||
}
|
||||
|
||||
deinit {
|
||||
llama_batch_free(batch)
|
||||
llama_free(context)
|
||||
llama_free_model(model)
|
||||
llama_sampling_free(sampling)
|
||||
llama_backend_free()
|
||||
}
|
||||
|
||||
|
@ -156,7 +159,7 @@ actor LlamaContext {
|
|||
candidates.withUnsafeMutableBufferPointer() { buffer in
|
||||
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
|
||||
|
||||
new_token_id = llama_sample_token_greedy(context, &candidates_p)
|
||||
new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
|
||||
}
|
||||
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
|
|
|
@ -44,7 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
|
|||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||
llama_sampling_accept(ctx_sampling, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
|
@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama));
|
||||
if (!ctx_sampling) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
|
@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_print_timings(ctx_llava->ctx_llama, nullptr);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_print_timings(ctx_llava->ctx_llama, nullptr);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
|
|
@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
|
||||
|
||||
// verification n-grams
|
||||
std::vector<ngram_data> ngrams_cur(G);
|
||||
|
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(ctx_sampling, id, true);
|
||||
|
||||
{
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
|
@ -286,7 +286,7 @@ int main(int argc, char ** argv) {
|
|||
// sample the next token
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(ctx_sampling, id, true);
|
||||
|
||||
// print
|
||||
{
|
||||
|
@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("n_predict = %d\n", n_predict);
|
||||
LOG_TEE("n_accept = %d\n", n_accept);
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, ctx_sampling->smpl);
|
||||
|
||||
llama_kv_cache_view_free(&kvc_view);
|
||||
llama_sampling_free(ctx_sampling);
|
||||
|
|
|
@ -106,7 +106,7 @@ int main(int argc, char ** argv){
|
|||
|
||||
bool has_eos = false;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
|
@ -132,7 +132,7 @@ int main(int argc, char ** argv){
|
|||
// sample from the target model
|
||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(ctx_sampling, id, true);
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
|
||||
|
@ -241,7 +241,7 @@ int main(int argc, char ** argv){
|
|||
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||
|
||||
LOG_TEE("\ntarget:\n");
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, ctx_sampling->smpl);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
llama_batch_free(batch_tgt);
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
|
||||
static llama_context ** g_ctx;
|
||||
static llama_model ** g_model;
|
||||
static llama_sampling_context ** g_ctx_sampling;
|
||||
static gpt_params * g_params;
|
||||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
|
@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
|
|||
} else {
|
||||
console::cleanup();
|
||||
printf("\n");
|
||||
llama_print_timings(*g_ctx);
|
||||
llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
||||
_exit(130);
|
||||
}
|
||||
|
@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
|
|||
|
||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
|
||||
llama_chat_msg new_msg{role, content};
|
||||
auto formatted = llama_chat_format_single(
|
||||
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
chat_msgs.push_back({role, content});
|
||||
LOG("formatted: %s\n", formatted.c_str());
|
||||
return formatted;
|
||||
|
@ -198,12 +198,16 @@ int main(int argc, char ** argv) {
|
|||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
llama_context * ctx_guidance = NULL;
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_guidance = nullptr;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
|
||||
std::vector<llama_chat_msg> chat_msgs;
|
||||
|
||||
g_model = &model;
|
||||
g_ctx = &ctx;
|
||||
g_ctx_sampling = &ctx_sampling;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
@ -531,7 +535,7 @@ int main(int argc, char ** argv) {
|
|||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||
}
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
|
||||
if (!ctx_sampling) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
|
@ -734,7 +738,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
||||
llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
|
||||
|
||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||
|
||||
|
@ -755,7 +759,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
|
||||
llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], /* apply_grammar= */ false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
|
@ -979,7 +983,7 @@ int main(int argc, char ** argv) {
|
|||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, ctx_sampling->smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
|
||||
if (ctx_guidance) { llama_free(ctx_guidance); }
|
||||
|
|
|
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
client.ctx_sampling = llama_sampling_init(params.sparams, model);
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
@ -343,7 +343,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||
|
||||
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(client.ctx_sampling, id, true);
|
||||
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
|
@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
llama_print_timings(ctx);
|
||||
// TODO: print sampling/grammar timings for all clients
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
|
@ -80,12 +80,13 @@ int main(int argc, char ** argv) {
|
|||
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_sampling * smpl = llama_get_sampling(ctx);
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
@ -230,7 +231,7 @@ int main(int argc, char ** argv) {
|
|||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
|
@ -267,7 +268,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
|
|
|
@ -2054,7 +2054,7 @@ int main(int argc, char ** argv) {
|
|||
results = perplexity(ctx, params, n_ctx);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
write_logfile(ctx, params, model, results);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
|
|
@ -293,7 +293,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
llama_backend_free();
|
||||
|
|
|
@ -38,6 +38,8 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_sampling * smpl = llama_get_sampling(ctx);
|
||||
|
||||
// tokenize prompt
|
||||
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
|
@ -73,7 +75,7 @@ int main(int argc, char ** argv) {
|
|||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx, &candidates_p);
|
||||
auto next_token = llama_sampling_sample(smpl, &candidates_p);
|
||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
|
@ -96,6 +98,8 @@ int main(int argc, char ** argv) {
|
|||
// make new context
|
||||
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampling * smpl2 = llama_get_sampling(ctx2);
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
|
@ -132,7 +136,7 @@ int main(int argc, char ** argv) {
|
|||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx2, &candidates_p);
|
||||
auto next_token = llama_sampling_sample(smpl2, &candidates_p);
|
||||
auto next_token_str = llama_token_to_piece(ctx2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
|
@ -157,7 +161,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// make new context
|
||||
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampling * smpl3 = llama_get_sampling(ctx3);
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
|
@ -223,7 +229,7 @@ int main(int argc, char ** argv) {
|
|||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx3, &candidates_p);
|
||||
auto next_token = llama_sampling_sample(smpl3, &candidates_p);
|
||||
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
|
@ -1099,7 +1098,8 @@ struct server_context {
|
|||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
}
|
||||
slot.ctx_sampling = llama_sampling_init(slot.sparams);
|
||||
|
||||
slot.ctx_sampling = llama_sampling_init(slot.sparams, model);
|
||||
if (slot.ctx_sampling == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
|
@ -2169,7 +2169,7 @@ struct server_context {
|
|||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
|
||||
llama_sampling_accept(slot.ctx_sampling, slot.cache_tokens[i], false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2401,7 +2401,7 @@ struct server_context {
|
|||
completion_token_output result;
|
||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
||||
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(slot.ctx_sampling, id, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
if (slot.n_decoded == 1) {
|
||||
|
@ -2419,7 +2419,7 @@ struct server_context {
|
|||
|
||||
// Make sure at least n_probs top tokens are at the front of the vector:
|
||||
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
|
||||
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
|
||||
llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
|
||||
}
|
||||
|
||||
if (slot.sparams.temp == 0.0f) {
|
||||
|
|
|
@ -55,6 +55,8 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_sampling * smpl = llama_get_sampling(ctx);
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
|
@ -123,7 +125,7 @@ int main(int argc, char ** argv) {
|
|||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
|
@ -160,7 +162,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, nullptr);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
|
|
|
@ -178,8 +178,8 @@ int main(int argc, char ** argv) {
|
|||
// used to determine end of generation
|
||||
bool has_eos = false;
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
// target model sampling context (reuse the llama_context's sampling instance)
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx_tgt));
|
||||
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
@ -190,7 +190,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||
// allocate llama_sampling for each draft sequence
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams, model_dft);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||
|
@ -234,8 +235,10 @@ int main(int argc, char ** argv) {
|
|||
// stochastic verification
|
||||
|
||||
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
||||
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
||||
float p_tgt = 0, p_dft = 0;
|
||||
llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt);
|
||||
|
||||
float p_tgt = 0.0f;
|
||||
float p_dft = 0.0f;
|
||||
|
||||
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||
|
||||
|
@ -277,7 +280,7 @@ int main(int argc, char ** argv) {
|
|||
accept = true;
|
||||
token_id = drafts[s].tokens[i_dft];
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
llama_sampling_accept(ctx_sampling, token_id, true);
|
||||
|
||||
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||
break;
|
||||
|
@ -331,8 +334,8 @@ int main(int argc, char ** argv) {
|
|||
// all drafted tokens were rejected
|
||||
// sample from the target model
|
||||
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
||||
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
token_id = llama_sampling_sample(ctx_sampling->smpl, &dist_tgt);
|
||||
llama_sampling_accept(ctx_sampling, token_id, true);
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
}
|
||||
|
||||
|
@ -343,7 +346,7 @@ int main(int argc, char ** argv) {
|
|||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
llama_sampling_accept(ctx_sampling, token_id, true);
|
||||
|
||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||
|
||||
|
@ -518,7 +521,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const int s = sa[is];
|
||||
|
||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||
llama_sampling_accept(drafts[s].ctx_sampling, id, true);
|
||||
|
||||
drafts[s].tokens.push_back(id);
|
||||
// save cur_p.data into drafts[s].dists
|
||||
|
@ -593,10 +596,11 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||
|
||||
LOG_TEE("\ndraft:\n");
|
||||
llama_print_timings(ctx_dft);
|
||||
// TODO: print sampling/grammar timings for all drafts
|
||||
llama_print_timings(ctx_dft, nullptr);
|
||||
|
||||
LOG_TEE("\ntarget:\n");
|
||||
llama_print_timings(ctx_tgt);
|
||||
llama_print_timings(ctx_tgt, ctx_sampling->smpl);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
|
|
311
include/llama.h
311
include/llama.h
|
@ -53,6 +53,7 @@ extern "C" {
|
|||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
// struct llama_vocab; // TODO: add in the future
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
|
@ -355,53 +356,22 @@ extern "C" {
|
|||
void * kv_overrides; // pointer to vector containing overrides
|
||||
} llama_model_quantize_params;
|
||||
|
||||
// grammar types
|
||||
struct llama_grammar;
|
||||
|
||||
// grammar element type
|
||||
enum llama_gretype {
|
||||
// end of rule definition
|
||||
LLAMA_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
LLAMA_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
LLAMA_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
LLAMA_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||
|
||||
// any character (.)
|
||||
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||
};
|
||||
|
||||
typedef struct llama_grammar_element {
|
||||
enum llama_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} llama_grammar_element;
|
||||
// sampling types
|
||||
struct llama_sampling;
|
||||
|
||||
// performance timing information
|
||||
struct llama_timings {
|
||||
double t_start_ms;
|
||||
double t_end_ms;
|
||||
double t_load_ms;
|
||||
double t_sample_ms;
|
||||
double t_sampling_ms;
|
||||
double t_grammar_ms;
|
||||
double t_p_eval_ms;
|
||||
double t_eval_ms;
|
||||
|
||||
int32_t n_sample;
|
||||
int32_t n_sampling;
|
||||
int32_t n_grammar_sample;
|
||||
int32_t n_grammar_accept;
|
||||
int32_t n_p_eval;
|
||||
int32_t n_eval;
|
||||
};
|
||||
|
@ -452,23 +422,23 @@ extern "C" {
|
|||
LLAMA_API bool llama_supports_mlock (void);
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
|
@ -998,53 +968,85 @@ extern "C" {
|
|||
char * buf,
|
||||
int32_t length);
|
||||
|
||||
//
|
||||
// Grammar
|
||||
//
|
||||
|
||||
/// Initialize a llama_grammar.
|
||||
///
|
||||
/// @param rules The rule elements of the grammar to initialize.
|
||||
/// @param n_rules The number of rules.
|
||||
/// @param start_rule_index The index of the root rule (the starting point of the grammar).
|
||||
/// @return The initialized llama_grammar or nullptr if initialization failed.
|
||||
LLAMA_API struct llama_grammar * llama_grammar_init(
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
||||
|
||||
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_grammar_sample(
|
||||
const struct llama_grammar * grammar,
|
||||
const struct llama_context * ctx,
|
||||
llama_token_data_array * candidates);
|
||||
LLAMA_API DEPRECATED(void llama_sample_grammar(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const struct llama_grammar * grammar),
|
||||
"use llama_grammar_sample instead");
|
||||
|
||||
/// @details Accepts the sampled token into the grammar
|
||||
LLAMA_API void llama_grammar_accept_token(
|
||||
struct llama_grammar * grammar,
|
||||
struct llama_context * ctx,
|
||||
llama_token token);
|
||||
|
||||
//
|
||||
// Sampling functions
|
||||
//
|
||||
|
||||
// TODO: args become llama_sampling_params
|
||||
// TODO: llama_model should become llama_vocab
|
||||
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root);
|
||||
|
||||
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
|
||||
|
||||
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
|
||||
|
||||
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sampling_softmax(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sampling_top_k(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
int32_t k,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sampling_top_p(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
||||
LLAMA_API void llama_sampling_min_p(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API void llama_sampling_tail_free(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float z,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API void llama_sampling_typical(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
|
||||
LLAMA_API void llama_sampling_entropy(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates_p,
|
||||
float min_temp,
|
||||
float max_temp,
|
||||
float exponent_val);
|
||||
|
||||
LLAMA_API void llama_sampling_temp(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float temp);
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_sampling_grammar(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
LLAMA_API void llama_sampling_repetition_penalties(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
|
@ -1056,73 +1058,20 @@ extern "C" {
|
|||
/// @param logits Logits extracted from the original generation context.
|
||||
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
LLAMA_API void llama_sample_apply_guidance(
|
||||
struct llama_context * ctx,
|
||||
LLAMA_API void llama_sampling_apply_guidance(
|
||||
struct llama_sampling * smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_k(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
int32_t k,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API void llama_sample_top_p(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
||||
LLAMA_API void llama_sample_min_p(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API void llama_sample_tail_free(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
float z,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API void llama_sample_typical(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
float p,
|
||||
size_t min_keep);
|
||||
|
||||
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
|
||||
LLAMA_API void llama_sample_entropy(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates_p,
|
||||
float min_temp,
|
||||
float max_temp,
|
||||
float exponent_val);
|
||||
|
||||
LLAMA_API void llama_sample_temp(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
float temp);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat(
|
||||
struct llama_context * ctx,
|
||||
LLAMA_API llama_token llama_sampling_sample_mirostat(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float tau,
|
||||
float eta,
|
||||
|
@ -1134,24 +1083,29 @@ extern "C" {
|
|||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
LLAMA_API llama_token llama_sample_token_mirostat_v2(
|
||||
struct llama_context * ctx,
|
||||
LLAMA_API llama_token llama_sampling_sample_mirostat_v2(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
float tau,
|
||||
float eta,
|
||||
float * mu);
|
||||
|
||||
/// @details Selects the token with the highest probability.
|
||||
/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
|
||||
LLAMA_API llama_token llama_sample_token_greedy(
|
||||
struct llama_context * ctx,
|
||||
/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
|
||||
LLAMA_API llama_token llama_sampling_sample_greedy(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
|
||||
LLAMA_API llama_token llama_sample_token(
|
||||
struct llama_context * ctx,
|
||||
/// @details Randomly selects a token from the candidates based on their probabilities
|
||||
LLAMA_API llama_token llama_sampling_sample(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Accepts the sampled token into the grammar
|
||||
LLAMA_API void llama_sampling_accept(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token token);
|
||||
|
||||
//
|
||||
// Model split
|
||||
//
|
||||
|
@ -1169,8 +1123,8 @@ extern "C" {
|
|||
// Performance information
|
||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl);
|
||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
@ -1185,59 +1139,4 @@ extern "C" {
|
|||
}
|
||||
#endif
|
||||
|
||||
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
||||
#ifdef LLAMA_API_INTERNAL
|
||||
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct ggml_tensor;
|
||||
|
||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||
struct llama_context * ctx
|
||||
);
|
||||
|
||||
struct llama_partial_utf8 {
|
||||
uint32_t value; // bit value so far (unshifted)
|
||||
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||
};
|
||||
|
||||
struct llama_grammar_candidate {
|
||||
size_t index;
|
||||
const uint32_t * code_points;
|
||||
llama_partial_utf8 partial_utf8;
|
||||
};
|
||||
|
||||
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
||||
|
||||
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||
|
||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
const uint32_t chr,
|
||||
llama_grammar_stacks & new_stacks);
|
||||
|
||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stack & stack,
|
||||
const llama_grammar_candidates & candidates);
|
||||
|
||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||
const std::string & src,
|
||||
llama_partial_utf8 partial_start);
|
||||
|
||||
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
|
||||
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
|
||||
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||
|
||||
#endif // LLAMA_API_INTERNAL
|
||||
|
||||
#endif // LLAMA_H
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -2,38 +2,153 @@
|
|||
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
struct llama_vocab;
|
||||
struct llama_sampling;
|
||||
|
||||
// grammar element type
|
||||
enum llama_gretype {
|
||||
// end of rule definition
|
||||
LLAMA_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
LLAMA_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
LLAMA_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
LLAMA_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||
|
||||
// any character (.)
|
||||
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||
};
|
||||
|
||||
typedef struct llama_grammar_element {
|
||||
enum llama_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} llama_grammar_element;
|
||||
|
||||
struct llama_partial_utf8 {
|
||||
uint32_t value; // bit value so far (unshifted)
|
||||
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||
};
|
||||
|
||||
struct llama_grammar_candidate {
|
||||
size_t index;
|
||||
const uint32_t * code_points;
|
||||
llama_partial_utf8 partial_utf8;
|
||||
};
|
||||
|
||||
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
||||
|
||||
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||
|
||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||
|
||||
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
||||
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
|
||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||
const std::string & src,
|
||||
llama_partial_utf8 partial_start);
|
||||
|
||||
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||
// produces the N possible stacks if the given char is accepted at those
|
||||
// positions
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
const uint32_t chr,
|
||||
llama_grammar_stacks & new_stacks);
|
||||
|
||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stack & stack,
|
||||
const llama_grammar_candidates & candidates);
|
||||
|
||||
struct llama_grammar_parser {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
|
||||
llama_grammar_rules rules;
|
||||
|
||||
llama_grammar_stack c_rules() const;
|
||||
|
||||
uint32_t get_symbol_id(const char * src, size_t len);
|
||||
uint32_t generate_symbol_id(const std::string & base_name);
|
||||
|
||||
void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
|
||||
|
||||
const char * parse_alternates(
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
llama_grammar_rule & rule,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_rule(const char * src);
|
||||
|
||||
bool parse(const char * src);
|
||||
void print(FILE * file);
|
||||
};
|
||||
|
||||
struct llama_grammar {
|
||||
const llama_grammar_rules rules;
|
||||
const llama_vocab & vocab;
|
||||
|
||||
const llama_grammar_rules rules; // TODO: shared ptr
|
||||
llama_grammar_stacks stacks;
|
||||
|
||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||
llama_partial_utf8 partial_utf8;
|
||||
|
||||
mutable int64_t t_total_us;
|
||||
|
||||
mutable int32_t n_sample;
|
||||
mutable int32_t n_accept;
|
||||
};
|
||||
|
||||
//
|
||||
// internal API
|
||||
//
|
||||
|
||||
// TODO: temporary until the tests are fixed
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
||||
|
||||
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
||||
|
||||
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
|
||||
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar);
|
||||
|
||||
void llama_grammar_sample_impl(
|
||||
const struct llama_grammar * grammar,
|
||||
const struct llama_vocab * vocab,
|
||||
const struct llama_sampling * smpl,
|
||||
// TODO: move the API below as member functions of llama_grammar
|
||||
void llama_grammar_apply_impl(
|
||||
const struct llama_grammar & grammar,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
void llama_grammar_accept_token_impl(
|
||||
struct llama_grammar * grammar,
|
||||
const struct llama_vocab * vocab,
|
||||
const struct llama_sampling * smpl,
|
||||
void llama_grammar_accept_impl(
|
||||
struct llama_grammar & grammar,
|
||||
llama_token token);
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
#pragma once
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "llama.h"
|
||||
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
|
@ -39,3 +42,7 @@ static void replace_all(std::string & s, const std::string & search, const std::
|
|||
pos += replace.length();
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||
struct llama_context * ctx
|
||||
);
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#include "llama-sampling.h"
|
||||
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
|
@ -21,19 +24,66 @@ static void llama_log_softmax(float * array, size_t size) {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
|
||||
llama_sampling::llama_sampling(uint32_t n_vocab) : n_vocab(n_vocab) {
|
||||
}
|
||||
|
||||
llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : n_vocab(vocab.n_vocab) {
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root);
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampling::~llama_sampling() {
|
||||
if (grammar) {
|
||||
llama_grammar_free_impl(grammar);
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
||||
return new llama_sampling(vocab, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
||||
delete sampling;
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
||||
auto * result = new llama_sampling(smpl.n_vocab);
|
||||
|
||||
if (smpl.grammar) {
|
||||
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
|
||||
// TODO: this is dumb, need to fix
|
||||
const struct llama_vocab * vocab = nullptr;
|
||||
|
||||
if (smpl.grammar) {
|
||||
vocab = &smpl.grammar->vocab;
|
||||
|
||||
llama_grammar_free_impl(smpl.grammar);
|
||||
smpl.grammar = nullptr;
|
||||
}
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = time(NULL);
|
||||
}
|
||||
|
||||
smpl->rng.seed(seed);
|
||||
smpl.rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
|
@ -44,28 +94,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||
|
||||
float max_l = candidates->data[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float p = expf(candidates->data[i].logit - max_l);
|
||||
candidates->data[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (k <= 0) {
|
||||
k = candidates->size;
|
||||
}
|
||||
|
@ -133,20 +179,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||
candidates->sorted = true;
|
||||
}
|
||||
candidates->size = k;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
|
@ -165,19 +205,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p <= 0.0f || !candidates->size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
bool min_p_applied = false;
|
||||
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
|
@ -226,19 +260,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
|
|||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates->size - 1);
|
||||
|
@ -285,13 +314,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
|
|||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
// Reference implementation:
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
if (p >= 1.0f) {
|
||||
|
@ -299,9 +324,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||
}
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
|
@ -349,15 +372,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||
candidates->size = new_candidates.size();
|
||||
candidates->sorted = false;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||
// no need to do anything if there is only one (or zero) candidates
|
||||
if(candidates->size <= 1) {
|
||||
return;
|
||||
|
@ -366,7 +383,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||
// Calculate maximum possible entropy
|
||||
float max_entropy = -logf(1.0f / candidates->size);
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Calculate entropy of the softmax probabilities
|
||||
float entropy = 0.0f;
|
||||
|
@ -416,26 +433,22 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
|
|||
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) {
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].logit /= temp;
|
||||
}
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
void llama_sampling_grammar_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
|
||||
if (smpl.grammar) {
|
||||
llama_grammar_apply_impl(*smpl.grammar, candidates);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
struct llama_sampling * smpl,
|
||||
void llama_sampling_repetition_penalties_impl(
|
||||
struct llama_sampling & /*smpl*/,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
|
@ -446,8 +459,6 @@ void llama_sample_repetition_penalties_impl(
|
|||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Create a frequency map to count occurrences of each token in last_tokens
|
||||
std::unordered_map<llama_token, int> token_count;
|
||||
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||
|
@ -475,43 +486,30 @@ void llama_sample_repetition_penalties_impl(
|
|||
}
|
||||
|
||||
candidates->sorted = false;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_apply_guidance_impl(
|
||||
struct llama_sampling * smpl,
|
||||
void llama_sampling_apply_guidance_impl(
|
||||
struct llama_sampling & smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale) {
|
||||
GGML_ASSERT(smpl);
|
||||
|
||||
const auto t_start_sample_us = ggml_time_us();
|
||||
const auto n_vocab = smpl->n_vocab;
|
||||
const auto n_vocab = smpl.n_vocab;
|
||||
|
||||
llama_log_softmax(logits, n_vocab);
|
||||
llama_log_softmax(logits_guidance, n_vocab);
|
||||
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
for (uint32_t i = 0; i < n_vocab; ++i) {
|
||||
auto & l = logits[i];
|
||||
const auto & g = logits_guidance[i];
|
||||
|
||||
l = scale * (l - g) + g;
|
||||
}
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
GGML_ASSERT(smpl);
|
||||
llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
const int32_t n_vocab = float(smpl.n_vocab);
|
||||
|
||||
const int32_t n_vocab = float(smpl->n_vocab);
|
||||
|
||||
int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
|
@ -530,10 +528,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
|||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
llama_sampling_top_k_impl(smpl, candidates, int(k), 1);
|
||||
llama_token X = llama_sampling_sample_impl(smpl, candidates);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
|
@ -545,15 +541,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
|
|||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
|
@ -564,16 +556,11 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
|
|||
candidates->size = 1;
|
||||
}
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax_impl(smpl, candidates);
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
llama_token X = llama_sampling_sample_impl(smpl, candidates);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
|
@ -585,33 +572,22 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
|
|||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
|
||||
// Find max element
|
||||
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit < b.logit;
|
||||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
if (smpl) {
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
smpl->n_sample++;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
GGML_ASSERT(smpl);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||
llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
llama_sampling_softmax_impl(smpl, candidates);
|
||||
|
||||
std::vector<float> probs;
|
||||
probs.reserve(candidates->size);
|
||||
|
@ -624,12 +600,17 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama
|
|||
|
||||
llama_token result = candidates->data[idx].id;
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
smpl->n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
|
||||
llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
|
||||
return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng);
|
||||
}
|
||||
|
||||
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token) {
|
||||
// TODO: implement token storing in history
|
||||
|
||||
if (smpl.grammar) {
|
||||
llama_grammar_accept_impl(*smpl.grammar, token);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,40 +1,54 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
struct llama_vocab;
|
||||
struct llama_grammar;
|
||||
|
||||
struct llama_sampling {
|
||||
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
|
||||
llama_sampling(uint32_t n_vocab);
|
||||
llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
||||
~llama_sampling();
|
||||
|
||||
const uint32_t n_vocab;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
int32_t n_vocab = 0;
|
||||
struct llama_grammar * grammar = nullptr;
|
||||
|
||||
mutable int64_t t_total_us = 0;
|
||||
|
||||
mutable int64_t t_sample_us = 0;
|
||||
mutable int32_t n_sample = 0;
|
||||
|
||||
void reset_timings() const {
|
||||
t_sample_us = 0;
|
||||
n_sample = 0;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// internal API
|
||||
//
|
||||
|
||||
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
|
||||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
||||
|
||||
void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
||||
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
|
||||
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
|
||||
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
|
||||
void llama_sampling_free_impl(struct llama_sampling * sampling);
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
struct llama_sampling * smpl,
|
||||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl);
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
|
||||
|
||||
// TODO: move the API below as member functions of llama_sampling
|
||||
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
|
||||
|
||||
void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
|
||||
void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
||||
void llama_sampling_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sampling_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep);
|
||||
void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
|
||||
void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);
|
||||
void llama_sampling_grammar_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
|
||||
|
||||
void llama_sampling_repetition_penalties_impl(
|
||||
struct llama_sampling & smpl,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
|
@ -42,15 +56,16 @@ void llama_sample_repetition_penalties_impl(
|
|||
float penalty_freq,
|
||||
float penalty_present);
|
||||
|
||||
void llama_sample_apply_guidance_impl(
|
||||
struct llama_sampling * smpl,
|
||||
void llama_sampling_apply_guidance_impl(
|
||||
struct llama_sampling & smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale);
|
||||
|
||||
llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
|
||||
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||
llama_token llama_sampling_sample_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
|
||||
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
|
||||
llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||
llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
|
||||
|
||||
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token);
|
||||
|
|
|
@ -18,6 +18,8 @@ struct llama_vocab {
|
|||
tattr attr;
|
||||
};
|
||||
|
||||
uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
|
||||
|
||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
||||
|
@ -62,8 +64,6 @@ struct llama_vocab {
|
|||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||
};
|
||||
|
||||
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// internal API
|
||||
//
|
||||
|
@ -76,6 +76,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|||
bool add_special,
|
||||
bool parse_special = false);
|
||||
|
||||
// TODO: move the API below as member functions of llama_vocab
|
||||
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
||||
|
||||
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
|
320
src/llama.cpp
320
src/llama.cpp
|
@ -1,6 +1,5 @@
|
|||
#include "llama-impl.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#include "unicode.h"
|
||||
|
@ -148,6 +147,19 @@ static void zeros(std::ofstream & file, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
struct time_meas {
|
||||
time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
~time_meas() {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
|
||||
LLAMA_ATTRIBUTE_FORMAT(1, 2)
|
||||
static std::string format(const char * fmt, ...) {
|
||||
va_list ap;
|
||||
|
@ -2661,7 +2673,7 @@ struct llama_model {
|
|||
struct llama_context {
|
||||
llama_context(const llama_model & model)
|
||||
: model(model)
|
||||
, sampling(llama_n_vocab(&model))
|
||||
, sampling(model.vocab, nullptr, nullptr) // by default, no grammar
|
||||
, t_start_us(model.t_start_us)
|
||||
, t_load_us(model.t_load_us) {}
|
||||
|
||||
|
@ -2695,16 +2707,16 @@ struct llama_context {
|
|||
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
int64_t t_start_us;
|
||||
int64_t t_load_us;
|
||||
int64_t t_p_eval_us = 0;
|
||||
int64_t t_eval_us = 0;
|
||||
mutable int64_t t_start_us;
|
||||
mutable int64_t t_load_us;
|
||||
mutable int64_t t_p_eval_us = 0;
|
||||
mutable int64_t t_eval_us = 0;
|
||||
|
||||
int64_t t_compute_start_us = 0;
|
||||
int64_t n_queued_tokens = 0;
|
||||
mutable int64_t t_compute_start_us = 0;
|
||||
mutable int64_t n_queued_tokens = 0;
|
||||
|
||||
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||
int32_t n_eval = 0; // number of eval calls
|
||||
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||
mutable int32_t n_eval = 0; // number of eval calls
|
||||
|
||||
// host buffer for the model output (logits and embeddings)
|
||||
ggml_backend_buffer_t buf_output = nullptr;
|
||||
|
@ -5518,6 +5530,7 @@ static void llm_load_vocab(
|
|||
|
||||
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
|
||||
|
||||
vocab.n_vocab = n_vocab;
|
||||
vocab.id_to_token.resize(n_vocab);
|
||||
|
||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||
|
@ -16722,8 +16735,10 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->abort_callback = params.abort_callback;
|
||||
ctx->abort_callback_data = params.abort_callback_data;
|
||||
|
||||
ctx->sampling.rng = std::mt19937(params.seed);
|
||||
llama_sampling_set_rng_seed_impl(ctx->sampling, params.seed);
|
||||
|
||||
ctx->logits_all = params.logits_all;
|
||||
|
||||
// build worst-case graph for encoder if a model contains encoder
|
||||
ctx->is_encoding = llama_model_has_encoder(model);
|
||||
|
||||
|
@ -17001,14 +17016,6 @@ void llama_free(struct llama_context * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
|
||||
return &ctx->model;
|
||||
}
|
||||
|
||||
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
|
||||
return &ctx->model.vocab;
|
||||
}
|
||||
|
||||
uint32_t llama_n_ctx(const struct llama_context * ctx) {
|
||||
return ctx->cparams.n_ctx;
|
||||
}
|
||||
|
@ -17029,6 +17036,34 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
|
|||
return model->vocab.type;
|
||||
}
|
||||
|
||||
int32_t llama_n_vocab(const struct llama_model * model) {
|
||||
return model->hparams.n_vocab;
|
||||
}
|
||||
|
||||
int32_t llama_n_ctx_train(const struct llama_model * model) {
|
||||
return model->hparams.n_ctx_train;
|
||||
}
|
||||
|
||||
int32_t llama_n_embd(const struct llama_model * model) {
|
||||
return model->hparams.n_embd;
|
||||
}
|
||||
|
||||
int32_t llama_n_layer(const struct llama_model * model) {
|
||||
return model->hparams.n_layer;
|
||||
}
|
||||
|
||||
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
|
||||
return &ctx->model;
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_get_sampling(struct llama_context * ctx) {
|
||||
return &ctx->sampling;
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
|
||||
return ctx->cparams.pooling_type;
|
||||
}
|
||||
|
||||
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
// these models do not use RoPE
|
||||
|
@ -17089,26 +17124,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||
return LLAMA_ROPE_TYPE_NONE;
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
|
||||
return ctx->cparams.pooling_type;
|
||||
}
|
||||
|
||||
int32_t llama_n_vocab(const struct llama_model * model) {
|
||||
return model->hparams.n_vocab;
|
||||
}
|
||||
|
||||
int32_t llama_n_ctx_train(const struct llama_model * model) {
|
||||
return model->hparams.n_ctx_train;
|
||||
}
|
||||
|
||||
int32_t llama_n_embd(const struct llama_model * model) {
|
||||
return model->hparams.n_embd;
|
||||
}
|
||||
|
||||
int32_t llama_n_layer(const struct llama_model * model) {
|
||||
return model->hparams.n_layer;
|
||||
}
|
||||
|
||||
float llama_rope_freq_scale_train(const struct llama_model * model) {
|
||||
return model->hparams.rope_freq_scale_train;
|
||||
}
|
||||
|
@ -19058,125 +19073,165 @@ int32_t llama_chat_apply_template(
|
|||
return res;
|
||||
}
|
||||
|
||||
//
|
||||
// grammar
|
||||
//
|
||||
|
||||
struct llama_grammar * llama_grammar_init(
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index) {
|
||||
return llama_grammar_init_impl(rules, n_rules, start_rule_index);
|
||||
}
|
||||
|
||||
void llama_grammar_free(struct llama_grammar * grammar) {
|
||||
llama_grammar_free_impl(grammar);
|
||||
}
|
||||
|
||||
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
|
||||
return llama_grammar_copy_impl(grammar);
|
||||
}
|
||||
|
||||
void llama_grammar_sample(
|
||||
const struct llama_grammar * grammar,
|
||||
const struct llama_context * ctx,
|
||||
llama_token_data_array * candidates) {
|
||||
llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
|
||||
}
|
||||
|
||||
void llama_sample_grammar(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const struct llama_grammar * grammar) {
|
||||
llama_grammar_sample(grammar, ctx, candidates);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_token(
|
||||
struct llama_grammar * grammar,
|
||||
struct llama_context * ctx,
|
||||
llama_token token) {
|
||||
llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
|
||||
}
|
||||
|
||||
//
|
||||
// sampling
|
||||
//
|
||||
|
||||
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
|
||||
llama_set_rng_seed_impl(&ctx->sampling, seed);
|
||||
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
||||
return llama_sampling_init_impl(model->vocab, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
|
||||
void llama_sampling_free(struct llama_sampling * smpl) {
|
||||
if (smpl == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sampling_free_impl(smpl);
|
||||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
|
||||
struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
|
||||
return llama_sampling_cp_impl(*smpl);
|
||||
}
|
||||
|
||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
|
||||
void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
|
||||
llama_sampling_reset_impl(*smpl, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
|
||||
llama_sampling_set_rng_seed_impl(*smpl, seed);
|
||||
}
|
||||
|
||||
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
|
||||
void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_softmax_impl(*smpl, candidates);
|
||||
}
|
||||
|
||||
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
|
||||
void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_top_k_impl(*smpl, candidates, k, min_keep);
|
||||
}
|
||||
|
||||
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
|
||||
llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
|
||||
void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_top_p_impl(*smpl, candidates, p, min_keep);
|
||||
}
|
||||
|
||||
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
|
||||
void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_min_p_impl(*smpl, candidates, p, min_keep);
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_tail_free_impl(*smpl, candidates, z, min_keep);
|
||||
}
|
||||
|
||||
void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_typical_impl(*smpl, candidates, p, min_keep);
|
||||
}
|
||||
|
||||
void llama_sampling_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_entropy_impl(*smpl, candidates_p, min_temp, max_temp, exponent_val);
|
||||
}
|
||||
|
||||
void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float temp) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_temp_impl(*smpl, candidates_p, temp);
|
||||
}
|
||||
|
||||
void llama_sampling_grammar(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates) {
|
||||
time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling
|
||||
|
||||
llama_sampling_grammar_impl(*smpl, candidates);
|
||||
}
|
||||
|
||||
void llama_sampling_repetition_penalties(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present) {
|
||||
llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
||||
}
|
||||
|
||||
void llama_sample_apply_guidance(
|
||||
struct llama_context * ctx,
|
||||
void llama_sampling_apply_guidance(
|
||||
struct llama_sampling * smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale) {
|
||||
llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
llama_sampling_apply_guidance_impl(*smpl, logits, logits_guidance, scale);
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
|
||||
llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
auto res = llama_sampling_sample_mirostat_impl(*smpl, candidates, tau, eta, m, mu);
|
||||
|
||||
smpl->n_sample++;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
|
||||
llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
auto res = llama_sampling_sample_mirostat_v2_impl(*smpl, candidates, tau, eta, mu);
|
||||
|
||||
smpl->n_sample++;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
|
||||
llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
auto res = llama_sampling_sample_greedy_impl(*smpl, candidates);
|
||||
|
||||
smpl->n_sample++;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
|
||||
llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
auto res = llama_sampling_sample_impl(*smpl, candidates);
|
||||
|
||||
smpl->n_sample++;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token token) {
|
||||
time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling
|
||||
|
||||
llama_sampling_accept_impl(*smpl, token);
|
||||
}
|
||||
|
||||
//
|
||||
// model split
|
||||
//
|
||||
|
||||
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
||||
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
|
||||
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
|
||||
|
@ -19201,30 +19256,29 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
|
|||
return 0;
|
||||
}
|
||||
|
||||
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
||||
struct llama_timings result = {
|
||||
void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) {
|
||||
const llama_timings timings = {
|
||||
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
||||
/*.t_end_ms =*/ 1.00 * ggml_time_ms(),
|
||||
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
|
||||
/*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
|
||||
/*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_total_us : ctx->sampling.t_total_us),
|
||||
/*.t_grammar_ms =*/ 1e-3 * (smpl && smpl->grammar ? smpl->grammar->t_total_us : 0.0),
|
||||
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
|
||||
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
|
||||
|
||||
/*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
|
||||
/*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : ctx->sampling.n_sample),
|
||||
/*.n_grammar_sample =*/ std::max(0, smpl && smpl->grammar ? smpl->grammar->n_sample : 0),
|
||||
/*.n_grammar_accept =*/ std::max(0, smpl && smpl->grammar ? smpl->grammar->n_accept : 0),
|
||||
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
|
||||
/*.n_eval =*/ std::max(1, ctx->n_eval),
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_print_timings(struct llama_context * ctx) {
|
||||
const llama_timings timings = llama_get_timings(ctx);
|
||||
|
||||
LLAMA_LOG_INFO("\n");
|
||||
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
|
||||
LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
|
||||
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
|
||||
LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, timings.t_grammar_ms, timings.n_grammar_sample, timings.t_grammar_ms / timings.n_grammar_sample, 1e3 / timings.t_grammar_ms * timings.n_grammar_sample);
|
||||
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
|
||||
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
|
@ -19232,12 +19286,18 @@ void llama_print_timings(struct llama_context * ctx) {
|
|||
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
|
||||
}
|
||||
|
||||
void llama_reset_timings(struct llama_context * ctx) {
|
||||
void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) {
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
ctx->t_eval_us = ctx->n_eval = 0;
|
||||
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
||||
|
||||
ctx->sampling.reset_timings();
|
||||
if (smpl) {
|
||||
smpl->t_total_us = smpl->n_sample = 0;
|
||||
|
||||
if (smpl->grammar) {
|
||||
smpl->grammar->t_total_us = smpl->grammar->n_sample = smpl->grammar->n_accept = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const char * llama_print_system_info(void) {
|
||||
|
@ -19279,21 +19339,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
|
|||
1.0e-3 * ctx->t_eval_us / ctx->n_eval);
|
||||
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
|
||||
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
|
||||
fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
|
||||
1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
|
||||
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
|
||||
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
|
||||
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
|
||||
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
|
||||
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
|
||||
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
|
||||
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
|
||||
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
|
||||
1.0e6 * ctx->n_eval / ctx->t_eval_us);
|
||||
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
|
||||
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
|
||||
fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
|
||||
1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
|
||||
}
|
||||
|
||||
// For internal test use
|
||||
|
|
|
@ -2,33 +2,23 @@
|
|||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama-vocab.h" // TMP
|
||||
#include "llama-grammar.h"
|
||||
#include "unicode.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static llama_grammar* build_grammar(const std::string & grammar_str) {
|
||||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
llama_vocab vocab; // TMP
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
||||
// Ensure we have a root node
|
||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
||||
|
||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
||||
llama_grammar* grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
|
||||
return grammar;
|
||||
static llama_grammar * build_grammar(const std::string & grammar_str) {
|
||||
return llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
|
||||
}
|
||||
|
||||
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||
|
@ -143,7 +133,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||
}
|
||||
|
||||
// Clean up allocated memory
|
||||
llama_grammar_free(grammar);
|
||||
llama_grammar_free_impl(grammar);
|
||||
}
|
||||
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
|
||||
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
|
||||
|
@ -683,7 +673,8 @@ static void test_failure_missing_root() {
|
|||
term ::= number
|
||||
number ::= [0-9]+)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
llama_grammar_parser parsed_grammar;
|
||||
parsed_grammar.parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
@ -705,7 +696,8 @@ static void test_failure_missing_reference() {
|
|||
|
||||
fprintf(stderr, " Expected error: ");
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
llama_grammar_parser parsed_grammar;
|
||||
parsed_grammar.parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we did NOT parsed correctly
|
||||
assert(parsed_grammar.rules.empty());
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#endif
|
||||
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
|
@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) {
|
|||
|
||||
static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
|
||||
uint32_t index = 0;
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
|
||||
llama_grammar_parser parsed_grammar;
|
||||
parsed_grammar.parse(grammar_bytes);
|
||||
|
||||
std::map<uint32_t, std::string> symbol_names;
|
||||
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
|
||||
|
@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
|
|||
}
|
||||
}
|
||||
|
||||
static void verify_failure(const char *grammar_bytes) {
|
||||
static void verify_failure(const char * grammar_bytes) {
|
||||
fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
|
||||
auto result = grammar_parser::parse(grammar_bytes);
|
||||
llama_grammar_parser result;
|
||||
result.parse(grammar_bytes);
|
||||
assert(result.rules.empty() && "should have failed");
|
||||
}
|
||||
|
||||
|
|
|
@ -2,14 +2,15 @@
|
|||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
static std::string trim(const std::string & source) {
|
||||
std::string s(source);
|
||||
s.erase(0,s.find_first_not_of(" \n\r\t"));
|
||||
|
@ -40,7 +41,8 @@ struct TestCase {
|
|||
}
|
||||
void verify_expectation_parseable() const {
|
||||
try {
|
||||
auto state = grammar_parser::parse(expected_grammar.c_str());
|
||||
llama_grammar_parser state;
|
||||
state.parse(expected_grammar.c_str());
|
||||
if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
|
||||
throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
|
||||
}
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
#include "llama-vocab.h" // TMP
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
|
||||
int main()
|
||||
{
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
llama_grammar_parser parsed_grammar;
|
||||
|
||||
std::vector<std::pair<std::string, uint32_t>> expected = {
|
||||
{"expr", 2},
|
||||
|
@ -117,7 +117,8 @@ int main()
|
|||
llama_grammar * grammar = NULL;
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
|
||||
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
llama_vocab vocab; // TMP
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
if (grammar == nullptr)
|
||||
{
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
|
@ -403,6 +404,8 @@ int main()
|
|||
delete[] candidate.code_points;
|
||||
candidate.code_points = nullptr;
|
||||
}
|
||||
llama_grammar_free(grammar);
|
||||
|
||||
llama_grammar_free_impl(grammar);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
|
@ -20,6 +21,8 @@ static void dump(const llama_token_data_array * candidates) {
|
|||
|
||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -28,9 +31,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
||||
llama_sampling_top_k_impl(smpl, &candidates_p, k, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
|
@ -41,6 +44,8 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||
|
||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -49,9 +54,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
||||
llama_sampling_top_p_impl(smpl, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
|
@ -62,6 +67,8 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||
|
||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -71,7 +78,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
||||
llama_sampling_tail_free_impl(smpl, &candidates_p, z, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
|
@ -82,6 +89,8 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||
|
||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -91,9 +100,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
|||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_min_p(nullptr, &candidates_p, p, 1);
|
||||
llama_sampling_min_p_impl(smpl, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
|
@ -103,6 +112,8 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
|||
|
||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -112,7 +123,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
|||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
||||
llama_sampling_typical_impl(smpl, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
|
@ -128,6 +139,8 @@ static void test_repetition_penalties(
|
|||
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||
|
||||
const size_t n_vocab = probs.size();
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -136,10 +149,10 @@ static void test_repetition_penalties(
|
|||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
llama_sampling_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
|
@ -148,9 +161,10 @@ static void test_repetition_penalties(
|
|||
}
|
||||
}
|
||||
|
||||
static void test_sampler_queue(
|
||||
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
llama_sampling smpl(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
|
@ -165,16 +179,16 @@ static void test_sampler_queue(
|
|||
|
||||
for (auto s : samplers_sequence) {
|
||||
switch (s){
|
||||
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
|
||||
case 'f': GGML_ABORT("tail_free test not implemented"); break;
|
||||
case 'y': GGML_ABORT("typical test not implemented"); break;
|
||||
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
|
||||
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
|
||||
case 't': GGML_ABORT("temperature test not implemented"); break;
|
||||
default : GGML_ABORT("Unknown sampler"); break;
|
||||
case 'k': llama_sampling_top_k_impl(smpl, &candidates_p, top_k, 1); break;
|
||||
case 'f': GGML_ABORT("tail_free test not implemented");
|
||||
case 'y': GGML_ABORT("typical test not implemented");
|
||||
case 'p': llama_sampling_top_p_impl(smpl, &candidates_p, top_p, 1); break;
|
||||
case 'm': llama_sampling_min_p_impl(smpl, &candidates_p, min_p, 1); break;
|
||||
case 't': GGML_ABORT("temperature test not implemented");
|
||||
default : GGML_ABORT("Unknown sampler");
|
||||
}
|
||||
|
||||
llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
|
||||
llama_sampling_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests
|
||||
|
||||
const int size = candidates_p.size;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue