Merge branch 'master' into gguf

Georgi Gerganov · 2023-08-18 12:38:05 +03:00 · commit 856afff746
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)
6 changed files with 1717 additions and 979 deletions

README.md

@@ -238,12 +238,17 @@ In order to build llama.cpp you have three different options.
cmake --build . --config Release
```
-- Using `Zig`:
+- Using `Zig` (version 0.11 or later):
Building for specific optimization levels and CPU features can be accomplished using standard build arguments (for example AVX2, FMA, F16C);
it is also possible to cross-compile for other operating systems and architectures:
```bash
-zig build -Doptimize=ReleaseFast
+zig build -Doptimize=ReleaseFast -Dtarget=x86_64-windows-gnu -Dcpu=x86_64+avx2+fma+f16c
```
The `zig targets` command will give you valid options to use.
- Using `gmake` (FreeBSD):
1. Install and activate [DRM in FreeBSD](https://wiki.freebsd.org/Graphics)
@@ -408,7 +413,7 @@ Building the program with BLAS support may lead to some performance improvements
|-------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
-| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
+| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
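As a rough sketch of how such options are applied at build time (assuming the cuBLAS backend is enabled via `LLAMA_CUBLAS`, which is not part of this hunk; the specific values below are illustrative assumptions, not recommendations):
```bash
# Hypothetical CUDA build: enable cuBLAS and tune the kernels described in the table above.
mkdir build && cd build
cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_CUDA_F16=ON -DLLAMA_CUDA_DMMV_X=64 -DLLAMA_CUDA_MMV_Y=2
cmake --build . --config Release
```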

build.zig

@@ -1,5 +1,6 @@
// Compatible with Zig Version 0.11.0
const std = @import("std");
+const ArrayList = std.ArrayList;
const Compile = std.Build.Step.Compile;
const ConfigHeader = std.Build.Step.ConfigHeader;
const Mode = std.builtin.Mode;
@@ -10,11 +11,31 @@ const Maker = struct {
target: CrossTarget,
optimize: Mode,
config_header: *ConfigHeader,
+enable_lto: bool,
-const cflags = .{"-std=c11"};
-const cxxflags = .{"-std=c++11"};
+include_dirs: ArrayList([]const u8),
+cflags: ArrayList([]const u8),
+cxxflags: ArrayList([]const u8),
+objs: ArrayList(*Compile),
-fn init(builder: *std.build.Builder) Maker {
+fn addInclude(m: *Maker, dir: []const u8) !void {
try m.include_dirs.append(dir);
}
fn addProjectInclude(m: *Maker, path: []const []const u8) !void {
try m.addInclude(try m.builder.build_root.join(m.builder.allocator, path));
}
fn addCFlag(m: *Maker, flag: []const u8) !void {
try m.cflags.append(flag);
}
fn addCxxFlag(m: *Maker, flag: []const u8) !void {
try m.cxxflags.append(flag);
}
fn addFlag(m: *Maker, flag: []const u8) !void {
try m.addCFlag(flag);
try m.addCxxFlag(flag);
}
fn init(builder: *std.build.Builder) !Maker {
const commit_hash = @embedFile(".git/refs/heads/master");
const config_header = builder.addConfigHeader(
.{ .style = .blank, .include_path = "build-info.h" },
@@ -23,58 +44,71 @@ const Maker = struct {
.BUILD_COMMIT = commit_hash[0 .. commit_hash.len - 1], // omit newline
},
);
-return Maker{
+var m = Maker{
.builder = builder,
.target = builder.standardTargetOptions(.{}),
.optimize = builder.standardOptimizeOption(.{}),
.config_header = config_header,
.enable_lto = false,
.include_dirs = ArrayList([]const u8).init(builder.allocator),
.cflags = ArrayList([]const u8).init(builder.allocator),
.cxxflags = ArrayList([]const u8).init(builder.allocator),
.objs = ArrayList(*Compile).init(builder.allocator),
};
try m.addCFlag("-std=c11");
try m.addCxxFlag("-std=c++11");
try m.addProjectInclude(&.{});
try m.addProjectInclude(&.{"examples"});
return m;
}
fn obj(m: *const Maker, name: []const u8, src: []const u8) *Compile {
const o = m.builder.addObject(.{ .name = name, .target = m.target, .optimize = m.optimize });
if (std.mem.endsWith(u8, src, ".c")) {
-o.addCSourceFiles(&.{src}, &cflags);
+o.addCSourceFiles(&.{src}, m.cflags.items);
o.linkLibC();
} else {
-o.addCSourceFiles(&.{src}, &cxxflags);
+o.addCSourceFiles(&.{src}, m.cxxflags.items);
o.linkLibCpp();
}
-o.addIncludePath(.{ .path = "." });
-o.addIncludePath(.{ .path = "./examples" });
+for (m.include_dirs.items) |i| o.addIncludePath(.{ .path = i });
+o.want_lto = m.enable_lto;
return o;
}
fn exe(m: *const Maker, name: []const u8, src: []const u8, deps: []const *Compile) *Compile {
const e = m.builder.addExecutable(.{ .name = name, .target = m.target, .optimize = m.optimize });
-e.addIncludePath(.{ .path = "." });
-e.addIncludePath(.{ .path = "./examples" });
-e.addCSourceFiles(&.{src}, &cxxflags);
+e.addCSourceFiles(&.{src}, m.cxxflags.items);
for (deps) |d| e.addObject(d);
for (m.objs.items) |o| e.addObject(o);
for (m.include_dirs.items) |i| e.addIncludePath(.{ .path = i });
e.linkLibC();
e.linkLibCpp();
e.addConfigHeader(m.config_header);
m.builder.installArtifact(e);
e.want_lto = m.enable_lto;
// Currently a bug is preventing correct linking for optimized builds for Windows:
// https://github.com/ziglang/zig/issues/15958
if (e.target.isWindows()) {
e.want_lto = false;
}
return e;
}
};
-pub fn build(b: *std.build.Builder) void {
-const make = Maker.init(b);
+pub fn build(b: *std.build.Builder) !void {
+var make = try Maker.init(b);
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
if (b.option(bool, "k-quants", "Enable K-quants, (default: true)") orelse true) {
try make.addFlag("-DGGML_USE_K_QUANTS");
const k_quants = make.obj("k_quants", "k_quants.c");
try make.objs.append(k_quants);
}
const ggml = make.obj("ggml", "ggml.c");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const llama = make.obj("llama", "llama.cpp");
const common = make.obj("common", "examples/common.cpp");
+const console = make.obj("common", "examples/console.cpp");
const grammar_parser = make.obj("grammar-parser", "examples/grammar-parser.cpp");
-_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, llama, common, grammar_parser });
+_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, llama, common, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, llama });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, llama, common });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, llama, common });

File diff suppressed because it is too large.


@@ -170,6 +170,136 @@
grammar: '',
})
/* START: Support for storing prompt templates and parameters in browser LocalStorage */
const local_storage_storageKey = "llamacpp_server_local_storage";
function local_storage_setDataFromObject(tag, content) {
localStorage.setItem(local_storage_storageKey + '/' + tag, JSON.stringify(content));
}
function local_storage_setDataFromRawText(tag, content) {
localStorage.setItem(local_storage_storageKey + '/' + tag, content);
}
function local_storage_getDataAsObject(tag) {
const item = localStorage.getItem(local_storage_storageKey + '/' + tag);
if (!item) {
return null;
} else {
return JSON.parse(item);
}
}
function local_storage_getDataAsRawText(tag) {
const item = localStorage.getItem(local_storage_storageKey + '/' + tag);
if (!item) {
return null;
} else {
return item;
}
}
// create a container for user templates and settings
const savedUserTemplates = signal({})
const selectedUserTemplate = signal({ name: '', template: { session: {}, params: {} } })
// let's import locally saved templates and settings if there are any
// user templates and settings are stored in one object
// in form of { "templatename": "templatedata" } and { "settingstemplatename":"settingsdata" }
console.log('Importing saved templates')
let importedTemplates = local_storage_getDataAsObject('user_templates')
if (importedTemplates) {
// saved templates were successfully imported.
console.log('Processing saved templates and updating default template')
//console.log(importedTemplates);
savedUserTemplates.value = importedTemplates;
//override default template
savedUserTemplates.value.default = { session: session.value, params: params.value }
local_storage_setDataFromObject('user_templates', savedUserTemplates.value)
} else {
// no saved templates detected.
console.log('Initializing LocalStorage and saving default template')
savedUserTemplates.value = { "default": { session: session.value, params: params.value } }
local_storage_setDataFromObject('user_templates', savedUserTemplates.value)
}
function userTemplateResetToDefault() {
console.log('Resetting template to default')
selectedUserTemplate.value.name = 'default';
selectedUserTemplate.value.data = savedUserTemplates.value['default'];
}
function userTemplateApply(t) {
session.value = t.data.session;
params.value = t.data.params;
}
function userTemplateResetToDefaultAndApply() {
userTemplateResetToDefault()
userTemplateApply(selectedUserTemplate.value)
}
function userTemplateLoadAndApplyAutosaved() {
// get autosaved last used template
let lastUsedTemplate = local_storage_getDataAsObject('user_templates_last')
if (lastUsedTemplate) {
console.log('Autosaved template found, restoring')
selectedUserTemplate.value = lastUsedTemplate
}
else {
console.log('No autosaved template found, using default template')
// no autosaved last used template was found, so load from default.
userTemplateResetToDefault()
}
console.log('Applying template')
// and update internal data from templates
userTemplateApply(selectedUserTemplate.value)
}
//console.log(savedUserTemplates.value)
//console.log(selectedUserTemplate.value)
function userTemplateAutosave() {
console.log('Template Autosave...')
if (selectedUserTemplate.value.name == 'default') {
// we don't want to save over default template, so let's create a new one
let newTemplateName = 'UserTemplate-' + Date.now().toString()
let newTemplate = { 'name': newTemplateName, 'data': { 'session': session.value, 'params': params.value } }
console.log('Saving as ' + newTemplateName)
// save in the autosave slot
local_storage_setDataFromObject('user_templates_last', newTemplate)
// and load it back and apply
userTemplateLoadAndApplyAutosaved()
} else {
local_storage_setDataFromObject('user_templates_last', { 'name': selectedUserTemplate.value.name, 'data': { 'session': session.value, 'params': params.value } })
}
}
console.log('Checking for autosaved last used template')
userTemplateLoadAndApplyAutosaved()
/* END: Support for storing prompt templates and parameters in browser LocalStorage */
const llamaStats = signal(null)
const controller = signal(null)
@@ -346,8 +476,34 @@
`
};
const userTemplateReset = (e) => {
e.preventDefault();
userTemplateResetToDefaultAndApply()
}
const UserTemplateResetButton = () => {
if (selectedUserTemplate.value.name == 'default') {
return html`
<button disabled>Using default template</button>
`
}
return html`
<button onclick=${userTemplateReset}>Reset all to default</button>
`
};
useEffect(() => {
// autosave template on every change
userTemplateAutosave()
}, [session.value, params.value])
return html`
<form>
+<fieldset>
+<${UserTemplateResetButton}/>
+</fieldset>
<fieldset>
<div>
<label for="prompt">Prompt</label>

llama.cpp

@@ -2574,37 +2574,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
// grammar - internal
//
struct llama_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};
struct llama_grammar {
const std::vector<std::vector<llama_grammar_element>> rules;
std::vector<std::vector<const llama_grammar_element *>> stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
};
struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
};
-// NOTE: assumes valid utf8 (but checks for overrun)
-// adds a terminating 0 for use as pointer
-std::vector<uint32_t> decode_utf8(const char * src) {
-static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+const char * src,
+llama_partial_utf8 partial_start) {
+static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src;
std::vector<uint32_t> code_points;
uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain;
// continue previous decode, if applicable
while (*pos != 0 && n_remain > 0) {
uint8_t next_byte = static_cast<uint8_t>(*pos);
if ((next_byte >> 6) != 2) {
// invalid sequence, abort
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
}
value = (value << 6) + (next_byte & 0x3F);
++pos;
--n_remain;
}
if (partial_start.n_remain > 0 && n_remain == 0) {
code_points.push_back(value);
}
// decode any subsequent utf-8 sequences, which may end in an incomplete one
while (*pos != 0) {
uint8_t first_byte = static_cast<uint8_t>(*pos);
uint8_t highbits = first_byte >> 4;
-int len = lookup[highbits];
-uint8_t mask = (1 << (8 - len)) - 1;
-uint32_t value = first_byte & mask;
-const char * end = pos + len; // may overrun!
-++pos;
-for ( ; pos < end && *pos != 0; ++pos) {
-value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+n_remain = lookup[highbits] - 1;
+if (n_remain < 0) {
+// invalid sequence, abort
+code_points.clear();
+code_points.push_back(0);
+return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
}
uint8_t mask = (1 << (7 - n_remain)) - 1;
value = first_byte & mask;
++pos;
while (*pos != 0 && n_remain > 0) {
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
++pos;
--n_remain;
}
if (n_remain == 0) {
code_points.push_back(value);
}
+}
code_points.push_back(0);
-return code_points;
+return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}
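To illustrate the partial-sequence handling: if a token's text ends with the bytes `0xE2 0x82` (the first two bytes of the three-byte encoding of U+20AC), `decode_utf8` returns the code points decoded so far together with `llama_partial_utf8{ value = 0x82, n_remain = 1 }`; when the next token begins with `0xAC`, the "continue previous decode" loop above combines them into `(0x82 << 6) + 0x2C = 0x20AC`.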
// returns true iff pos points to the end of one of the definitions of a rule
@@ -2641,6 +2685,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
return std::make_pair(found == is_positive_char, pos);
}
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
// range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
static bool llama_grammar_match_partial_char(
const llama_grammar_element * pos,
const llama_partial_utf8 partial_utf8) {
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
uint32_t partial_value = partial_utf8.value;
int n_remain = partial_utf8.n_remain;
// invalid sequence or 7-bit char split across 2 bytes (overlong)
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
return false;
}
// range of possible code points this partial UTF-8 sequence could complete to
uint32_t low = partial_value << (n_remain * 6);
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
if (low == 0) {
if (n_remain == 2) {
low = 1 << 11;
} else if (n_remain == 3) {
low = 1 << 16;
}
}
do {
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
// inclusive range, e.g. [a-z]
if (pos->value <= high && low <= pos[1].value) {
return is_positive_char;
}
pos += 2;
} else {
// exact char match, e.g. [a] or "a"
if (low <= pos->value && pos->value <= high) {
return is_positive_char;
}
pos += 1;
}
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
return !is_positive_char;
}
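As a worked example of the range computation above: a token ending in the lone byte `0xC3` (the first byte of a two-byte sequence) yields `value = 0x03` and `n_remain = 1`, so `low = 0x03 << 6 = 0xC0` and `high = 0xC0 | 0x3F = 0xFF`; the partial sequence can only complete to code points U+00C0 through U+00FF, so a grammar position that only accepts, say, `[a-z]` rejects it immediately.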
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
@@ -2741,8 +2835,11 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
std::vector<llama_grammar_candidate> rejects;
if (stack.empty()) {
-// accept nothing; EOS is handled elsewhere
-rejects.insert(rejects.end(), candidates.begin(), candidates.end());
+for (auto tok : candidates) {
+if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+rejects.push_back(tok);
+}
+}
return rejects;
}
@@ -2750,10 +2847,15 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
std::vector<llama_grammar_candidate> next_candidates;
for (auto tok : candidates) {
-if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
-if (tok.code_points[1] != 0) {
-next_candidates.push_back({ tok.index, tok.code_points + 1 });
+if (*tok.code_points == 0) {
+// reached end of full codepoints in token, reject iff it ended in a partial sequence
+// that cannot satisfy this position in grammar
+if (tok.partial_utf8.n_remain != 0 &&
+!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+rejects.push_back(tok);
}
+} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
+next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
} else {
rejects.push_back(tok);
}
@@ -2771,7 +2873,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
for (auto tok : next_rejects) {
-rejects.push_back({ tok.index, tok.code_points - 1 });
+rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
}
return rejects;
@@ -2836,7 +2938,7 @@ struct llama_grammar * llama_grammar_init(
}
} while (true);
-return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
+return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}
void llama_grammar_free(struct llama_grammar * grammar) {
@@ -3141,7 +3243,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
const llama_token eos = llama_token_eos();
-std::vector<std::vector<uint32_t>> candidates_decoded;
+std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
for (size_t i = 0; i < candidates->size; ++i) {
@@ -3154,8 +3256,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} else if (str.empty()) {
candidates->data[i].logit = -INFINITY;
} else {
-candidates_decoded.push_back(decode_utf8(str.c_str()));
-candidates_grammar.push_back({ i, candidates_decoded.back().data() });
+candidates_decoded.push_back(decode_utf8(str.c_str(), grammar->partial_utf8));
+candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
@@ -3354,12 +3456,15 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}
-std::string str = llama_token_to_str(ctx, token);
+const std::string str = llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string
-auto code_points = decode_utf8(str.c_str());
+const auto decoded = decode_utf8(str.c_str(), grammar->partial_utf8);
+const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
+grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;


@@ -199,7 +199,7 @@ int main()
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
cp[0] = 37 + i;
cp[1] = 0;
-next_candidates[i] = {i, cp};
+next_candidates[i] = {i, cp, {}};
}
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {