Merge branch 'master' into server-probs

jhen 2023-08-18 07:38:35 +08:00
commit 5c6aee64de
17 changed files with 3121 additions and 2112 deletions


@ -296,7 +296,6 @@ if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED) find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
@ -313,7 +312,6 @@ if (LLAMA_METAL)
${FOUNDATION_LIBRARY} ${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK} ${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK} ${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
) )
endif() endif()
@ -571,6 +569,16 @@ install(
WORLD_READ WORLD_READ
WORLD_EXECUTE WORLD_EXECUTE
DESTINATION ${CMAKE_INSTALL_BINDIR}) DESTINATION ${CMAKE_INSTALL_BINDIR})
if (LLAMA_METAL)
install(
FILES ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
# #
# programs, examples and tests # programs, examples and tests


@ -2,7 +2,7 @@
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test
# Binaries only useful for tests # Binaries only useful for tests
TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
default: $(BUILD_TARGETS) default: $(BUILD_TARGETS)
@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
ifdef LLAMA_METAL ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJS += ggml-metal.o OBJS += ggml-metal.o
endif # LLAMA_METAL endif # LLAMA_METAL
@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS) tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)


@ -238,12 +238,17 @@ In order to build llama.cpp you have three different options.
cmake --build . --config Release cmake --build . --config Release
``` ```
- Using `Zig`: - Using `Zig` (version 0.11 or later):
Building for specific optimization levels and CPU features can be accomplished using standard build arguments, for example AVX2, FMA, and F16C;
it is also possible to cross-compile for other operating systems and architectures:
```bash ```bash
zig build -Doptimize=ReleaseFast zig build -Doptimize=ReleaseFast -Dtarget=x86_64-windows-gnu -Dcpu=x86_64+avx2+fma+f16c
``` ```
The `zig targets` command will give you valid options to use.
- Using `gmake` (FreeBSD): - Using `gmake` (FreeBSD):
1. Install and activate [DRM in FreeBSD](https://wiki.freebsd.org/Graphics) 1. Install and activate [DRM in FreeBSD](https://wiki.freebsd.org/Graphics)
@ -408,7 +413,7 @@ Building the program with BLAS support may lead to some performance improvements
|-------------------------|------------------------|---------|-------------| |-------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |


@ -1,5 +1,6 @@
// Compatible with Zig Version 0.11.0 // Compatible with Zig Version 0.11.0
const std = @import("std"); const std = @import("std");
const ArrayList = std.ArrayList;
const Compile = std.Build.Step.Compile; const Compile = std.Build.Step.Compile;
const ConfigHeader = std.Build.Step.ConfigHeader; const ConfigHeader = std.Build.Step.ConfigHeader;
const Mode = std.builtin.Mode; const Mode = std.builtin.Mode;
@ -10,11 +11,31 @@ const Maker = struct {
target: CrossTarget, target: CrossTarget,
optimize: Mode, optimize: Mode,
config_header: *ConfigHeader, config_header: *ConfigHeader,
enable_lto: bool,
const cflags = .{"-std=c11"}; include_dirs: ArrayList([]const u8),
const cxxflags = .{"-std=c++11"}; cflags: ArrayList([]const u8),
cxxflags: ArrayList([]const u8),
objs: ArrayList(*Compile),
fn init(builder: *std.build.Builder) Maker { fn addInclude(m: *Maker, dir: []const u8) !void {
try m.include_dirs.append(dir);
}
fn addProjectInclude(m: *Maker, path: []const []const u8) !void {
try m.addInclude(try m.builder.build_root.join(m.builder.allocator, path));
}
fn addCFlag(m: *Maker, flag: []const u8) !void {
try m.cflags.append(flag);
}
fn addCxxFlag(m: *Maker, flag: []const u8) !void {
try m.cxxflags.append(flag);
}
fn addFlag(m: *Maker, flag: []const u8) !void {
try m.addCFlag(flag);
try m.addCxxFlag(flag);
}
fn init(builder: *std.build.Builder) !Maker {
const commit_hash = @embedFile(".git/refs/heads/master"); const commit_hash = @embedFile(".git/refs/heads/master");
const config_header = builder.addConfigHeader( const config_header = builder.addConfigHeader(
.{ .style = .blank, .include_path = "build-info.h" }, .{ .style = .blank, .include_path = "build-info.h" },
@ -23,58 +44,71 @@ const Maker = struct {
.BUILD_COMMIT = commit_hash[0 .. commit_hash.len - 1], // omit newline .BUILD_COMMIT = commit_hash[0 .. commit_hash.len - 1], // omit newline
}, },
); );
return Maker{ var m = Maker{
.builder = builder, .builder = builder,
.target = builder.standardTargetOptions(.{}), .target = builder.standardTargetOptions(.{}),
.optimize = builder.standardOptimizeOption(.{}), .optimize = builder.standardOptimizeOption(.{}),
.config_header = config_header, .config_header = config_header,
.enable_lto = false,
.include_dirs = ArrayList([]const u8).init(builder.allocator),
.cflags = ArrayList([]const u8).init(builder.allocator),
.cxxflags = ArrayList([]const u8).init(builder.allocator),
.objs = ArrayList(*Compile).init(builder.allocator),
}; };
try m.addCFlag("-std=c11");
try m.addCxxFlag("-std=c++11");
try m.addProjectInclude(&.{});
try m.addProjectInclude(&.{"examples"});
return m;
} }
fn obj(m: *const Maker, name: []const u8, src: []const u8) *Compile { fn obj(m: *const Maker, name: []const u8, src: []const u8) *Compile {
const o = m.builder.addObject(.{ .name = name, .target = m.target, .optimize = m.optimize }); const o = m.builder.addObject(.{ .name = name, .target = m.target, .optimize = m.optimize });
if (std.mem.endsWith(u8, src, ".c")) { if (std.mem.endsWith(u8, src, ".c")) {
o.addCSourceFiles(&.{src}, &cflags); o.addCSourceFiles(&.{src}, m.cflags.items);
o.linkLibC(); o.linkLibC();
} else { } else {
o.addCSourceFiles(&.{src}, &cxxflags); o.addCSourceFiles(&.{src}, m.cxxflags.items);
o.linkLibCpp(); o.linkLibCpp();
} }
o.addIncludePath(.{ .path = "." }); for (m.include_dirs.items) |i| o.addIncludePath(.{ .path = i });
o.addIncludePath(.{ .path = "./examples" }); o.want_lto = m.enable_lto;
return o; return o;
} }
fn exe(m: *const Maker, name: []const u8, src: []const u8, deps: []const *Compile) *Compile { fn exe(m: *const Maker, name: []const u8, src: []const u8, deps: []const *Compile) *Compile {
const e = m.builder.addExecutable(.{ .name = name, .target = m.target, .optimize = m.optimize }); const e = m.builder.addExecutable(.{ .name = name, .target = m.target, .optimize = m.optimize });
e.addIncludePath(.{ .path = "." }); e.addCSourceFiles(&.{src}, m.cxxflags.items);
e.addIncludePath(.{ .path = "./examples" });
e.addCSourceFiles(&.{src}, &cxxflags);
for (deps) |d| e.addObject(d); for (deps) |d| e.addObject(d);
for (m.objs.items) |o| e.addObject(o);
for (m.include_dirs.items) |i| e.addIncludePath(.{ .path = i });
e.linkLibC(); e.linkLibC();
e.linkLibCpp(); e.linkLibCpp();
e.addConfigHeader(m.config_header); e.addConfigHeader(m.config_header);
m.builder.installArtifact(e); m.builder.installArtifact(e);
e.want_lto = m.enable_lto;
// Currently a bug is preventing correct linking for optimized builds for Windows:
// https://github.com/ziglang/zig/issues/15958
if (e.target.isWindows()) {
e.want_lto = false;
}
return e; return e;
} }
}; };
pub fn build(b: *std.build.Builder) void { pub fn build(b: *std.build.Builder) !void {
const make = Maker.init(b); var make = try Maker.init(b);
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
if (b.option(bool, "k-quants", "Enable K-quants, (default: true)") orelse true) {
try make.addFlag("-DGGML_USE_K_QUANTS");
const k_quants = make.obj("k_quants", "k_quants.c");
try make.objs.append(k_quants);
}
const ggml = make.obj("ggml", "ggml.c"); const ggml = make.obj("ggml", "ggml.c");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c"); const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const llama = make.obj("llama", "llama.cpp"); const llama = make.obj("llama", "llama.cpp");
const common = make.obj("common", "examples/common.cpp"); const common = make.obj("common", "examples/common.cpp");
const console = make.obj("common", "examples/console.cpp");
const grammar_parser = make.obj("grammar-parser", "examples/grammar-parser.cpp"); const grammar_parser = make.obj("grammar-parser", "examples/grammar-parser.cpp");
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, llama, common, grammar_parser }); _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, llama, common, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, llama }); _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, llama });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, llama, common }); _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, llama, common });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, llama, common }); _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, llama, common });


@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.cfg_negative_prompt = argv[i]; params.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-negative-prompt-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
if (params.cfg_negative_prompt.back() == '\n') {
params.cfg_negative_prompt.pop_back();
}
} else if (arg == "--cfg-scale") { } else if (arg == "--cfg-scale") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
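For reference, a minimal standalone sketch of the file-reading pattern the new `--cfg-negative-prompt-file` argument uses: slurp the whole file into the prompt string and drop a single trailing newline. The `read_prompt_file` helper is hypothetical (not part of common.cpp) and adds an empty-string guard before calling `back()`.

```cpp
#include <algorithm>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>

// Hypothetical helper illustrating the same pattern as the new option above.
static bool read_prompt_file(const char * path, std::string & out) {
    std::ifstream file(path);
    if (!file) {
        fprintf(stderr, "error: failed to open file '%s'\n", path);
        return false;
    }
    // copy the entire file contents into the prompt string
    std::copy(std::istreambuf_iterator<char>(file),
              std::istreambuf_iterator<char>(),
              std::back_inserter(out));
    // drop a single trailing newline so it does not leak into the guidance prompt
    if (!out.empty() && out.back() == '\n') {
        out.pop_back();
    }
    return true;
}
```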
@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " --cfg-negative-prompt PROMPT\n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-negative-prompt-file FNAME\n");
fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);

File diff suppressed because it is too large


@ -188,6 +188,136 @@
n_probs: 0, // no completion_probabilities n_probs: 0, // no completion_probabilities
}) })
/* START: Support for storing prompt templates and parameters in browser LocalStorage */
const local_storage_storageKey = "llamacpp_server_local_storage";
function local_storage_setDataFromObject(tag, content) {
localStorage.setItem(local_storage_storageKey + '/' + tag, JSON.stringify(content));
}
function local_storage_setDataFromRawText(tag, content) {
localStorage.setItem(local_storage_storageKey + '/' + tag, content);
}
function local_storage_getDataAsObject(tag) {
const item = localStorage.getItem(local_storage_storageKey + '/' + tag);
if (!item) {
return null;
} else {
return JSON.parse(item);
}
}
function local_storage_getDataAsRawText(tag) {
const item = localStorage.getItem(local_storage_storageKey + '/' + tag);
if (!item) {
return null;
} else {
return item;
}
}
// create a container for user templates and settings
const savedUserTemplates = signal({})
const selectedUserTemplate = signal({ name: '', template: { session: {}, params: {} } })
// let's import locally saved templates and settings if there are any
// user templates and settings are stored in one object
// in the form of { "templatename": "templatedata" } and { "settingstemplatename": "settingsdata" }
console.log('Importing saved templates')
let importedTemplates = local_storage_getDataAsObject('user_templates')
if (importedTemplates) {
// saved templates were successfully imported.
console.log('Processing saved templates and updating default template')
//console.log(importedTemplates);
savedUserTemplates.value = importedTemplates;
//override default template
savedUserTemplates.value.default = { session: session.value, params: params.value }
local_storage_setDataFromObject('user_templates', savedUserTemplates.value)
} else {
// no saved templates detected.
console.log('Initializing LocalStorage and saving default template')
savedUserTemplates.value = { "default": { session: session.value, params: params.value } }
local_storage_setDataFromObject('user_templates', savedUserTemplates.value)
}
function userTemplateResetToDefault() {
console.log('Resetting template to default')
selectedUserTemplate.value.name = 'default';
selectedUserTemplate.value.data = savedUserTemplates.value['default'];
}
function userTemplateApply(t) {
session.value = t.data.session;
params.value = t.data.params;
}
function userTemplateResetToDefaultAndApply() {
userTemplateResetToDefault()
userTemplateApply(selectedUserTemplate.value)
}
function userTemplateLoadAndApplyAutosaved() {
// get autosaved last used template
let lastUsedTemplate = local_storage_getDataAsObject('user_templates_last')
if (lastUsedTemplate) {
console.log('Autosaved template found, restoring')
selectedUserTemplate.value = lastUsedTemplate
}
else {
console.log('No autosaved template found, using default template')
// no autosaved last used template was found, so load from default.
userTemplateResetToDefault()
}
console.log('Applying template')
// and update internal data from templates
userTemplateApply(selectedUserTemplate.value)
}
//console.log(savedUserTemplates.value)
//console.log(selectedUserTemplate.value)
function userTemplateAutosave() {
console.log('Template Autosave...')
if (selectedUserTemplate.value.name == 'default') {
// we don't want to save over default template, so let's create a new one
let newTemplateName = 'UserTemplate-' + Date.now().toString()
let newTemplate = { 'name': newTemplateName, 'data': { 'session': session.value, 'params': params.value } }
console.log('Saving as ' + newTemplateName)
// save in the autosave slot
local_storage_setDataFromObject('user_templates_last', newTemplate)
// and load it back and apply
userTemplateLoadAndApplyAutosaved()
} else {
local_storage_setDataFromObject('user_templates_last', { 'name': selectedUserTemplate.value.name, 'data': { 'session': session.value, 'params': params.value } })
}
}
console.log('Checking for autosaved last used template')
userTemplateLoadAndApplyAutosaved()
/* END: Support for storing prompt templates and parameters in browser LocalStorage */
const llamaStats = signal(null) const llamaStats = signal(null)
const controller = signal(null) const controller = signal(null)
@ -382,8 +512,34 @@
` `
}; };
const userTemplateReset = (e) => {
e.preventDefault();
userTemplateResetToDefaultAndApply()
}
const UserTemplateResetButton = () => {
if (selectedUserTemplate.value.name == 'default') {
return html`
<button disabled>Using default template</button>
`
}
return html`
<button onclick=${userTemplateReset}>Reset all to default</button>
`
};
useEffect(() => {
// autosave template on every change
userTemplateAutosave()
}, [session.value, params.value])
return html` return html`
<form> <form>
<fieldset>
<${UserTemplateResetButton}/>
</fieldset>
<fieldset> <fieldset>
<div> <div>
<label for="prompt">Prompt</label> <label for="prompt">Prompt</label>


@ -14,8 +14,6 @@
with pkgs.darwin.apple_sdk_11_0.frameworks; [ with pkgs.darwin.apple_sdk_11_0.frameworks; [
Accelerate Accelerate
MetalKit MetalKit
MetalPerformanceShaders
MetalPerformanceShadersGraph
] ]
else if isAarch32 && isDarwin then else if isAarch32 && isDarwin then
with pkgs.darwin.apple_sdk.frameworks; [ with pkgs.darwin.apple_sdk.frameworks; [


@ -67,6 +67,8 @@ struct ggml_allocr {
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE]; struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size; size_t max_size;
bool measure; bool measure;
int parse_seq[GGML_MAX_NODES];
bool has_parse_seq;
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024]; struct ggml_tensor * allocated_tensors[1024];
@ -111,10 +113,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
size_t max_avail = 0; size_t max_avail = 0;
// find the best fitting free block // find the best fitting free block besides the last block
int best_fit_block = -1; int best_fit_block = -1;
size_t best_fit_size = SIZE_MAX; size_t best_fit_size = SIZE_MAX;
for (int i = 0; i < alloc->n_free_blocks; i++) { for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
struct free_block * block = &alloc->free_blocks[i]; struct free_block * block = &alloc->free_blocks[i];
max_avail = MAX(max_avail, block->size); max_avail = MAX(max_avail, block->size);
if (block->size >= size && block->size <= best_fit_size) { if (block->size >= size && block->size <= best_fit_size) {
@ -126,10 +128,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
AT_PRINTF("block %d\n", best_fit_block); AT_PRINTF("block %d\n", best_fit_block);
if (best_fit_block == -1) { if (best_fit_block == -1) {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", // the last block is our last resort
__func__, size, max_avail); struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
GGML_ASSERT(!"not enough space in the buffer"); if (block->size >= size) {
best_fit_block = alloc->n_free_blocks - 1;
max_avail = MAX(max_avail, block->size);
} else {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
return; return;
}
} }
struct free_block * block = &alloc->free_blocks[best_fit_block]; struct free_block * block = &alloc->free_blocks[best_fit_block];
void * addr = block->addr; void * addr = block->addr;
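A compact sketch of the block-selection strategy introduced above (plain C++ over a list of free-block sizes, not the actual ggml_allocr code; `pick_free_block` is an illustrative name): best fit among all blocks except the last, with the last block used only as a last resort when nothing else is large enough.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the index of the chosen free block, or -1 if even the last block is too small.
static int pick_free_block(const std::vector<std::size_t> & blocks, std::size_t size) {
    int         best_fit_block = -1;
    std::size_t best_fit_size  = SIZE_MAX;
    // best fit among all blocks except the last one
    for (std::size_t i = 0; i + 1 < blocks.size(); i++) {
        if (blocks[i] >= size && blocks[i] <= best_fit_size) {
            best_fit_block = (int) i;
            best_fit_size  = blocks[i];
        }
    }
    // the last block is the last resort
    if (best_fit_block == -1 && !blocks.empty() && blocks.back() >= size) {
        best_fit_block = (int) blocks.size() - 1;
    }
    return best_fit_block;
}
```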
@ -229,6 +238,17 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
alloc->n_free_blocks++; alloc->n_free_blocks++;
} }
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
int pos = 0;
for (int i = 0; i < n; i++) {
if (list[i] != -1) {
alloc->parse_seq[pos] = list[i];
pos++;
}
}
alloc->has_parse_seq = true;
}
void ggml_allocr_reset(struct ggml_allocr * alloc) { void ggml_allocr_reset(struct ggml_allocr * alloc) {
alloc->n_free_blocks = 1; alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
@ -248,6 +268,8 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
/*.hash_table = */ {{0}}, /*.hash_table = */ {{0}},
/*.max_size = */ 0, /*.max_size = */ 0,
/*.measure = */ false, /*.measure = */ false,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0}, /*.allocated_tensors = */ = {0},
#endif #endif
@ -275,6 +297,8 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
/*.hash_table = */ {{0}}, /*.hash_table = */ {{0}},
/*.max_size = */ 0, /*.max_size = */ 0,
/*.measure = */ true, /*.measure = */ true,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0}, /*.allocated_tensors = */ = {0},
#endif #endif
@ -473,7 +497,13 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
allocate_node(alloc, input); allocate_node(alloc, input);
} }
} }
for (int i = 0; i < gf->n_nodes; i++) { for (int ind = 0; ind < gf->n_nodes; ind++) {
int i;
if (alloc->has_parse_seq) {
i = alloc->parse_seq[ind];
} else {
i = ind;
}
struct ggml_tensor * node = gf->nodes[i]; struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs) // allocate parents (leafs)


@ -10,6 +10,10 @@ extern "C" {
GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph is optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc); GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);


@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
// try to find operations that can be run concurrently in the graph // try to find operations that can be run concurrently in the graph
// you should run it again if the topology of your graph changes // you should run it again if the topology of your graph changes
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
// if the graph has been optimized for concurrently dispatch // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx); int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// output the concur_list for ggml_alloc
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal // same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel // creates gf->n_threads command buffers in parallel
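Taken together with the allocator change above, the updated API is used roughly as in the following sketch. It mirrors the call sequence added to llama.cpp later in this commit; the wrapper function name is mine, and `ctx_metal`, `alloc` and `gf` are assumed to be initialized elsewhere.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-metal.h"

static void set_concurrent_parse_order(struct ggml_metal_context * ctx_metal,
                                       struct ggml_allocr        * alloc,
                                       struct ggml_cgraph        * gf) {
    // build the concurrency list once; llama.cpp passes check_mem = false here,
    // before the graph tensors have been allocated
    ggml_metal_graph_find_concurrency(ctx_metal, gf, false);

    // ggml_metal_if_optimized() now returns the length of the concur_list
    // (0 when no optimized order was found); the list may contain -1 separators,
    // which ggml_allocr_set_parse_seq() skips while copying the order
    ggml_allocr_set_parse_seq(alloc,
                              ggml_metal_get_concur_list(ctx_metal),
                              ggml_metal_if_optimized(ctx_metal));
}
```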


@ -5,7 +5,6 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import <Metal/Metal.h> #import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#undef MIN #undef MIN
#undef MAX #undef MAX
@ -79,6 +78,14 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0; ctx->n_buffers = 0;
ctx->concur_list_len = 0; ctx->concur_list_len = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
fprintf(stderr, "%s: using MPS\n", __func__);
} else {
fprintf(stderr, "%s: not using MPS\n", __func__);
GGML_ASSERT(false && "MPS not supported");
}
#if 0 #if 0
// compile from source string and show compile log // compile from source string and show compile log
@ -163,10 +163,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
// load kernels // load kernels
{ {
NSError * error = nil;
#define GGML_METAL_ADD_KERNEL(name) \ #define GGML_METAL_ADD_KERNEL(name) \
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
if (error) { \
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \
}
GGML_METAL_ADD_KERNEL(add); GGML_METAL_ADD_KERNEL(add);
GGML_METAL_ADD_KERNEL(add_row); GGML_METAL_ADD_KERNEL(add_row);
@ -196,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@ -228,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb; ctx->n_cb = n_cb;
} }
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) { int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
if (ctx->concur_list_len) { return ctx->concur_list_len;
return true; }
}
return false; int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
return ctx->concur_list;
} }
// finds the Metal buffer that contains the tensor data on the GPU device // finds the Metal buffer that contains the tensor data on the GPU device
@ -375,7 +389,7 @@ void ggml_metal_get_tensor(
void ggml_metal_graph_find_concurrency( void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf, bool check_mem) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_CONCUR]; int nodes_unused[GGML_MAX_CONCUR];
@ -422,7 +436,7 @@ void ggml_metal_graph_find_concurrency(
} }
} }
} }
if (exe_flag) { if (exe_flag && check_mem) {
// check if nodes[i]'s data will be overwritten by a node before nodes[i]. // check if nodes[i]'s data will be overwritten by a node before nodes[i].
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data; int64_t data_start = (int64_t) gf->nodes[i]->data;
@ -506,7 +520,7 @@ void ggml_metal_graph_compute(
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx]; id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = nil; id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@ -515,10 +529,6 @@ void ggml_metal_graph_compute(
const int i = has_concur ? ctx->concur_list[ind] : ind; const int i = has_concur ? ctx->concur_list[ind] : ind;
if (i == -1) { if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue; continue;
} }
@ -592,10 +602,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row]; [encoder setComputePipelineState:ctx->pipeline_add_row];
@ -613,10 +619,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_MUL: case GGML_OP_MUL:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row]; [encoder setComputePipelineState:ctx->pipeline_mul_row];
@ -634,10 +636,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale]; [encoder setComputePipelineState:ctx->pipeline_scale];
@ -653,10 +651,6 @@ void ggml_metal_graph_compute(
switch (ggml_get_unary_op(gf->nodes[i])) { switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_silu]; [encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -667,10 +661,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_relu]; [encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -681,10 +671,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_gelu]; [encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -701,10 +687,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32; const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max]; [encoder setComputePipelineState:ctx->pipeline_soft_max];
@ -719,10 +701,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *)(dst->op_params))[0]; const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@ -740,53 +718,43 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne03 == ne13);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) && if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) && ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
if (encoder != nil) { ne00%32 == 0 &&
[encoder endEncoding]; ne11 > 1) {
encoder = nil; switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} }
else {
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
// for F32 x F32 we use MPS
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
// we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne12; ++i02) {
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
int nth0 = 32; int nth0 = 32;
int nth1 = 1; int nth1 = 1;
@ -885,23 +853,24 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q3_K) { else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64 #ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else #else
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif #endif
} }
else if (src0t == GGML_TYPE_Q5_K) { else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@ -910,10 +879,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@ -939,10 +904,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
@ -962,10 +923,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_NORM: case GGML_OP_NORM:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float eps = 1e-5f; const float eps = 1e-5f;
const int nth = 256; const int nth = 256;
@ -984,10 +941,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_ALIBI: case GGML_OP_ALIBI:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
GGML_ASSERT((src0t == GGML_TYPE_F32)); GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@ -1027,10 +980,6 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -1071,10 +1020,6 @@ void ggml_metal_graph_compute(
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:
{ {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32; const int nth = 32;
switch (src0t) { switch (src0t) {

File diff suppressed because it is too large


@ -63,7 +63,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL) #if !defined(GGML_USE_CUBLAS)
#include "ggml-alloc.h" #include "ggml-alloc.h"
#define LLAMA_USE_ALLOCATOR #define LLAMA_USE_ALLOCATOR
#else #else
@ -1609,11 +1609,11 @@ static struct ggml_cgraph * llama_build_graph(
ggml_set_name(Q, "Q"); ggml_set_name(Q, "Q");
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctx0, ggml_view_3d(ctx0, kv_self.k,
ggml_reshape_3d(ctx0, n_embd_head, n_past + N, n_head_kv,
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa), ggml_element_size(kv_self.k)*n_embd_gqa,
n_embd_head, n_head_kv, n_past + N), ggml_element_size(kv_self.k)*n_embd_head,
0, 2, 1, 3); ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K); offload_func_kq(K);
ggml_set_name(K, "K"); ggml_set_name(K, "K");
@ -1642,9 +1642,9 @@ static struct ggml_cgraph * llama_build_graph(
struct ggml_tensor * V = struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v, ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv, n_past + N, n_embd_head, n_head_kv,
n_ctx*ggml_element_size(kv_self.v), ggml_element_size(kv_self.v)*n_ctx,
n_ctx*ggml_element_size(kv_self.v)*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il); ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V); offload_func_v(V);
ggml_set_name(V, "V"); ggml_set_name(V, "V");
@ -1845,11 +1845,7 @@ static bool llama_eval_internal(
#endif #endif
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) { if (lctx.ctx_metal) {
// TODO: disabled until #2413 is resolved
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
//}
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_graph_compute(lctx.ctx_metal, gf);
ggml_metal_get_tensor (lctx.ctx_metal, res); ggml_metal_get_tensor (lctx.ctx_metal, res);
@ -1857,22 +1853,6 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, embeddings); ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
} }
} else { } else {
// IMPORTANT:
// Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
// ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
// coprocessor.
//
// When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
// But for now, we have focused only on Matrix x Vector Metal multiplication.
//
// TODO: avoid these syncs via shared memory (ref #1696)
//
if (lctx.ctx_metal) {
// We need to sync the GPU KV cache with the CPU KV cache
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
} }
#else #else
@ -3303,7 +3283,18 @@ struct llama_context * llama_new_context_with_model(
int n_past = hparams.n_ctx - n_tokens; int n_past = hparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false);
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
// measure memory requirements for the graph // measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
@ -3321,6 +3312,11 @@ struct llama_context * llama_new_context_with_model(
ctx->buf_alloc.resize(alloc_size); ctx->buf_alloc.resize(alloc_size);
ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment); ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
#ifdef GGML_USE_METAL
if (ctx->ctx_metal) {
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
} }
#else #else
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
@ -3335,13 +3331,6 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) { if (params.n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers // this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
void * data_ptr = NULL; void * data_ptr = NULL;
size_t data_size = 0; size_t data_size = 0;
@ -3370,8 +3359,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.addr, ctx->buf_alloc.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
#undef LLAMA_METAL_CHECK_BUF #undef LLAMA_METAL_CHECK_BUF
} }
#endif #endif


@ -0,0 +1,3 @@
#!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip


@ -12,5 +12,6 @@ llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp) llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
llama_add_test(test-llama-grammar.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/common.cpp)
llama_add_test(test-grad0.cpp) # SLOW llama_add_test(test-grad0.cpp) # SLOW
# llama_add_test(test-opt.cpp) # SLOW # llama_add_test(test-opt.cpp) # SLOW


@ -0,0 +1,403 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include "llama.cpp"
#include "examples/common.cpp"
#include "examples/grammar-parser.cpp"
#include <cassert>
int main()
{
grammar_parser::parse_state parsed_grammar;
std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
{"expr_6", 6},
{"expr_7", 7},
{"ident", 8},
{"ident_10", 10},
{"num", 9},
{"num_11", 11},
{"root", 0},
{"root_1", 1},
{"root_5", 5},
{"term", 4},
{"ws", 3},
{"ws_12", 12},
};
std::vector<std::vector<llama_grammar_element>> expected_rules = {
{{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_RULE_REF, 2},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_CHAR, 10},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
{{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_RULE_REF, 8},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_RULE_REF, 9},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_CHAR, 40},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_RULE_REF, 2},
{LLAMA_GRETYPE_CHAR, 41},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 45},
{LLAMA_GRETYPE_CHAR_ALT, 43},
{LLAMA_GRETYPE_CHAR_ALT, 42},
{LLAMA_GRETYPE_CHAR_ALT, 47},
{LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 97},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
{LLAMA_GRETYPE_RULE_REF, 10},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_END, 0},
},
{{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
{
{LLAMA_GRETYPE_CHAR, 97},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
{LLAMA_GRETYPE_CHAR_ALT, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_CHAR_ALT, 95},
{LLAMA_GRETYPE_RULE_REF, 10},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0},
},
{
{LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_RULE_REF, 11},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_END, 0},
},
{
{LLAMA_GRETYPE_CHAR, 32},
{LLAMA_GRETYPE_CHAR_ALT, 9},
{LLAMA_GRETYPE_CHAR_ALT, 10},
{LLAMA_GRETYPE_RULE_REF, 12},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0},
},
};
for (auto pair : expected)
{
parsed_grammar.symbol_ids[pair.first] = pair.second;
}
for (auto rule : expected_rules)
{
parsed_grammar.rules.push_back({});
for (auto element : rule)
{
parsed_grammar.rules.back().push_back(element);
}
}
llama_grammar *grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 97},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 40},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 97},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_RULE_REF, 3},
{LLAMA_GRETYPE_CHAR, 48},
},
{
{LLAMA_GRETYPE_CHAR, 61},
{LLAMA_GRETYPE_RULE_REF, 7},
{LLAMA_GRETYPE_CHAR, 40},
}};
auto index = 0;
for (auto stack : grammar->stacks)
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
{
auto element = stack[i];
auto expected_element = expected_stacks[index][i];
// pretty print error message before asserting
if (expected_element.type != element->type || expected_element.value != element->value)
{
fprintf(stderr, "index: %d\n", index);
fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value);
fprintf(stderr, "expected_element != actual_element\n");
}
assert(expected_element.type == element->type && expected_element.value == element->value);
}
index++;
}
std::vector<std::vector<const llama_grammar_element *>> next_stacks;
std::vector<llama_grammar_candidate> next_candidates;
next_candidates.resize(24);
for (size_t i = 0; i < 24; ++i)
{
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
cp[0] = 37 + i;
cp[1] = 0;
next_candidates[i] = {i, cp};
}
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{3, 40},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{21, 58},
{22, 59},
{23, 60},
},
{
{0, 37},
{1, 38},
{2, 39},
{4, 41},
{5, 42},
{6, 43},
{7, 44},
{8, 45},
{9, 46},
{10, 47},
{11, 48},
{12, 49},
{13, 50},
{14, 51},
{15, 52},
{16, 53},
{17, 54},
{18, 55},
{19, 56},
{20, 57},
{21, 58},
{22, 59},
{23, 60},
},
};
std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
{
rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
all_rejects.push_back(rejects);
}
index = 0;
for (auto rej : all_rejects)
{
for (uint32_t i = 0; i < rej.size(); i++)
{
auto element = rej[i];
auto expected_element = expected_reject[index][i];
assert(element.index == expected_element.first && *element.code_points == expected_element.second);
}
index++;
}
for (auto &candidate : next_candidates)
{
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
delete grammar;
return 0;
}