Clean up and add more correctness

Jason Flax 2024-11-06 01:04:09 -05:00
parent 802687a4d8
commit b99e7f977f
25 changed files with 1095 additions and 390 deletions


@@ -32,7 +32,7 @@ var cSettings: [CSetting] = [
// We should consider adding this in the future when we drop support for iOS 14
// (ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
.define("ACCELERATE_NEW_LAPACK"),
.define("ACCELERATE_LAPACK_ILP64")
.define("ACCELERATE_LAPACK_ILP64"),
]
#if canImport(Darwin)
@@ -80,7 +80,8 @@ let package = Package(
"ggml"
],
sources: cppSources,
publicHeadersPath: "spm-headers"),
publicHeadersPath: "spm-headers",
cSettings: cSettings),
.target(
name: "llama",
dependencies: ["llama_cpp"],
@@ -94,6 +95,8 @@
dependencies: ["llama"],
path: "objc",
sources: [
"CPUParams.mm",
"GGMLThreadpool.mm",
"GPTParams.mm",
"GPTSampler.mm",
"LlamaBatch.mm",

objc/CPUParams.mm (new file, +68)

@@ -0,0 +1,68 @@
#import <Foundation/Foundation.h>
#import "CPUParams_Private.hpp"
#import "GGMLThreadpool_Private.hpp"
@implementation CPUParams
- (instancetype)initWithParams:(cpu_params&)params
{
self = [super init];
if (self) {
self->params = &params;
}
return self;
}
- (NSInteger)nThreads {
return params->n_threads;
}
- (void)setNThreads:(NSInteger)nThreads {
params->n_threads = nThreads;
}
- (BOOL)maskValid {
return params->mask_valid;
}
- (void)setMaskValid:(BOOL)maskValid {
params->mask_valid = maskValid;
}
- (GGMLSchedPriority)priority {
return GGMLSchedPriority(params->priority);
}
- (void)setPriority:(GGMLSchedPriority)priority {
params->priority = ggml_sched_priority(priority);
}
- (BOOL)strictCPU {
return params->strict_cpu;
}
- (void)setStrictCPU:(BOOL)strictCPU {
params->strict_cpu = strictCPU;
}
- (NSUInteger)poll {
return params->poll;
}
- (void)setPoll:(NSUInteger)poll {
params->poll = poll;
}
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
return params->cpumask[index];
}
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
params->cpumask[index] = value;
}
- (GGMLThreadpoolParams *)ggmlThreadpoolParams {
return [[GGMLThreadpoolParams alloc] initWithParams:ggml_threadpool_params_from_cpu_params(*params)];
}
@end
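
These accessors write straight through to the wrapped C++ cpu_params struct. A minimal Swift sketch of configuring them through a GPTParams instance (assuming the default Objective-C bridging of the headers below; the values are illustrative):

let params = GPTParams()
params.cpuParams.nThreads = 4         // worker threads for generation
params.cpuParams.poll = 50            // 0 = no polling, 100 = mostly polling
params.cpuParams.strictCPU = false    // no strict CPU placement
params.cpuParams.priority = .normal   // GGMLSchedPriorityNormal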

objc/GGMLThreadpool.mm (new file, +52)

@@ -0,0 +1,52 @@
#import <Foundation/Foundation.h>
#import "GGMLThreadpool_Private.hpp"
@implementation GGMLThreadpool
- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool
{
self = [super init];
if (self) {
self->threadpool = threadpool;
}
return self;
}
- (ggml_threadpool *)threadpool {
return self->threadpool;
}
@end
@implementation GGMLThreadpoolParams {
ggml_threadpool_params params;
}
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
abort();
}
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
abort();
}
- (instancetype)initWithParams:(ggml_threadpool_params&&)params
{
self = [super init];
if (self) {
self->params = params;
}
return self;
}
- (BOOL)isEqual:(id)other {
GGMLThreadpoolParams *rhs = (GGMLThreadpoolParams *)other;
ggml_threadpool_params rhs_params = rhs->params;
return ggml_threadpool_params_match(&params, &rhs_params);
}
- (GGMLThreadpool *)threadpool {
auto tp = ggml_threadpool_new(&params);
return [[GGMLThreadpool alloc] initWithThreadpool:tp];
}
@end
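
GGMLThreadpoolParams equality is backed by ggml_threadpool_params_match, which lets callers skip building a second pool when the batch configuration is identical. A hedged sketch of the wiring LlamaSession performs below (assumes ctx is an initialized LlamaContext and the default Swift bridging of these methods):

let tpp = params.cpuParams.ggmlThreadpoolParams()
let tppBatch = params.cpuParamsBatch.ggmlThreadpoolParams()
// Only create a separate batch pool when the configurations actually differ.
let batchPool = tpp.isEqual(tppBatch) ? nil : tppBatch.threadpool()
ctx.attachThreadpool(tpp.threadpool(), threadpoolBatch: batchPool)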


@@ -1,127 +1,10 @@
#import <Foundation/Foundation.h>
#import "GPTParams_Private.hpp"
#import "CPUParams_Private.hpp"
#import "GPTSampler.h"
#import "../common/common.h"
#import "ggml.h"
@implementation GGMLThreadpool {
ggml_threadpool *threadpool;
}
- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool
{
self = [super init];
if (self) {
self->threadpool = threadpool;
}
return self;
}
- (ggml_threadpool *)threadpool {
return threadpool;
}
@end
@implementation GGMLThreadpoolParams {
ggml_threadpool_params params;
}
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
abort();
}
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
abort();
}
- (instancetype)initWithParams:(ggml_threadpool_params&&)params
{
self = [super init];
if (self) {
self->params = params;
}
return self;
}
- (BOOL)isEqual:(id)other {
GGMLThreadpoolParams *rhs = (GGMLThreadpoolParams *)other;
ggml_threadpool_params rhs_params = rhs->params;
return ggml_threadpool_params_match(&params, &rhs_params);
}
- (GGMLThreadpool *)threadpool {
auto tp = ggml_threadpool_new(&params);
return [[GGMLThreadpool alloc] initWithThreadpool:tp];
}
@end
@implementation CPUParams {
cpu_params *params;
}
- (instancetype)initWithParams:(cpu_params&)params;
{
self = [super init];
if (self) {
self->params = &params;
}
return self;
}
- (int)nThreads {
return params->n_threads;
}
- (void)setNThreads:(int)nThreads {
params->n_threads = nThreads;
}
- (BOOL)maskValid {
return params->mask_valid;
}
- (void)setMaskValid:(BOOL)maskValid {
params->mask_valid = maskValid;
}
- (GGMLSchedPriority)priority {
return GGMLSchedPriority(params->priority);
}
- (void)setPriority:(GGMLSchedPriority)priority {
params->priority = ggml_sched_priority(priority);
}
- (BOOL)strictCPU {
return params->strict_cpu;
}
- (void)setStrictCPU:(BOOL)strictCPU {
params->strict_cpu = strictCPU;
}
- (uint32_t)poll {
return params->poll;
}
- (void)setPoll:(uint32_t)poll {
params->poll = poll;
}
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
return params->cpumask[index];
}
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
params->cpumask[index] = value;
}
- (GGMLThreadpoolParams *)ggmlThreadpoolParams {
return [[GGMLThreadpoolParams alloc] initWithParams:ggml_threadpool_params_from_cpu_params(*params)];
}
@end
@implementation GPTSamplerParams {
common_sampler_params *gpt_sampler_params;
}
@@ -415,6 +298,42 @@
return antiprompts;
}
- (void)setAntiPrompts:(NSArray<NSString *> *)antiPrompts {
gpt_params.antiprompt.clear();
for (NSString *antiprompt in antiPrompts) {
gpt_params.antiprompt.push_back([antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
- (NSArray<NSString *> *)apiKeys {
auto apiKeys = [[NSMutableArray alloc] init];
for (auto& apiKey : gpt_params.api_keys) {
[apiKeys addObject:[NSString stringWithCString:apiKey.c_str() encoding:NSUTF8StringEncoding]];
}
return apiKeys;
}
- (void)setApiKeys:(NSArray<NSString *> *)apiKeys {
gpt_params.api_keys.clear();
for (NSString *apiKey in apiKeys) {
gpt_params.api_keys.push_back([apiKey cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
- (NSArray<NSNumber *> *)tensorSplit {
auto tensorSplit = [[NSMutableArray alloc] init];
for (auto& tensor : gpt_params.tensor_split) {
[tensorSplit addObject:[[NSNumber alloc] initWithFloat:tensor]];
}
return tensorSplit;
}
- (void)setTensorSplit:(NSArray<NSNumber *> *)tensorSplit {
for (size_t i = 0; i < [tensorSplit count]; i++) {
gpt_params.tensor_split[i] = [tensorSplit[i] floatValue];
}
}
- (common_params&)params {
return gpt_params;
}
@@ -716,4 +635,29 @@
gpt_params.input_suffix = [inputSuffix cStringUsingEncoding:NSUTF8StringEncoding];
}
- (BOOL)interactive {
return gpt_params.interactive;
}
- (void)setInteractive:(BOOL)interactive {
gpt_params.interactive = interactive;
}
- (BOOL)interactiveFirst {
return gpt_params.interactive_first;
}
- (void)setInteractiveFirst:(BOOL)interactiveFirst {
gpt_params.interactive_first = interactiveFirst;
}
- (id)copyWithZone:(NSZone *)zone {
GPTParams *copy = [[[self class] allocWithZone:zone] init];
if (copy) {
copy->gpt_params = gpt_params;
}
return copy;
}
@end
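
With NSCopying in place, LlamaSession can snapshot its parameters at init (see the [params copy] call below), so later mutations by the caller don't race the generation loop. An illustrative sketch:

let params = GPTParams()
params.interactive = true
let snapshot = params.copy() as! GPTParams  // independent common_params value
params.interactive = false
// snapshot.interactive is still true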


@@ -1,5 +1,6 @@
#import <Foundation/Foundation.h>
#import "LlamaContext_Private.hpp"
#import "GGMLThreadpool_Private.hpp"
#import "GPTParams_Private.hpp"
#import "LlamaModel_Private.hpp"
#import "LlamaBatch_Private.hpp"
@@ -17,6 +18,12 @@
return self;
}
- (void)dealloc
{
llama_free(ctx);
[super dealloc];
}
- (void)attachThreadpool:(GGMLThreadpool *)threadpool
threadpoolBatch:(GGMLThreadpool *)threadpoolBatch {
llama_attach_threadpool(ctx, [threadpool threadpool], [threadpoolBatch threadpool]);


@@ -22,6 +22,12 @@
return self;
}
- (void)dealloc
{
llama_free_model(model);
[super dealloc];
}
- (LlamaContext *)context:(LlamaContextParams *)params {
return nil;
}


@@ -3,10 +3,12 @@
#import "../../common/common.h"
#import "LlamaModel_Private.hpp"
#import "LlamaContext_Private.hpp"
#import "CPUParams_Private.hpp"
#import "GPTSampler.h"
#import <OSLog/OSLog.h>
#import "ggml.h"
#import "GPTParams_Private.hpp"
#import "GGMLThreadpool_Private.hpp"
#import "LlamaBatch_Private.hpp"
@implementation BlockingLineQueue {
@@ -75,8 +77,7 @@
@implementation LlamaSession {
std::vector<llama_token> embd_inp;
std::vector<common_chat_msg> chat_msgs;
GPTParams *params;
GPTSampler *smpl;
BOOL isInteracting;
bool is_antiprompt;
@@ -105,13 +106,14 @@
std::vector<std::vector<llama_token>> antiprompt_ids;
BOOL need_insert_eot;
int n_ctx;
os_log_t os_log_inst;
}
- (NSString *)chat_add_and_format:(std::vector<common_chat_msg> &) chat_msgs role:(const std::string &) role content:(const std::string &) content {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single([self.model cModel], [params params].chat_template, chat_msgs, new_msg, role == "user");
auto formatted = common_chat_format_single([self.model cModel], [_params params].chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
os_log_debug(OS_LOG_DEFAULT, "formatted: '%s'\n", formatted.c_str());
os_log_debug(os_log_inst, "formatted: '%s'\n", formatted.c_str());
return [NSString stringWithCString:formatted.c_str() encoding:NSUTF8StringEncoding];
}
@@ -131,22 +133,24 @@ static BOOL file_is_empty(NSString *path) {
- (instancetype)initWithParams:(GPTParams *)params {
self = [super init];
self->params = params;
// model = llama_init.model;
// ctx = llama_init.context;
//
// if model == nil {
// LOG_ERR("%s: error: unable to load model\n", __func__);
// return 1;
// }
//
os_log_info(OS_LOG_DEFAULT,
"%s: llama threadpool init, n_threads = %d\n",
__func__, params.cpuParams.nThreads);
self->_params = [params copy];
self->_mutableLastOutput = [[NSMutableString alloc] init];
if (params.logging) {
os_log_inst = OS_LOG_DEFAULT;
} else {
os_log_inst = OS_LOG_DISABLED;
}
if (!params.modelPath) {
[NSException raise:@"ModelFailure"
format:@"params.modelPath must be defined"];
}
os_log_info(os_log_inst,
"%s: llama threadpool init, n_threads = %ld\n",
__func__, static_cast<long>(params.cpuParams.nThreads));
if (params.embedding) {
os_log_error(OS_LOG_DEFAULT,
os_log_error(os_log_inst,
R"(************
please use the 'embedding' tool for embedding calculations
************)");
@@ -154,22 +158,29 @@ static BOOL file_is_empty(NSString *path) {
}
if (params.nCtx != 0 && params.nCtx < 8) {
os_log_info(OS_LOG_DEFAULT, "minimum context size is 8, using minimum size.");
os_log_info(os_log_inst, "minimum context size is 8, using minimum size.");
params.nCtx = 8;
}
if (params.ropeFreqBase != 0) {
os_log_info(OS_LOG_DEFAULT, "changing RoPE frequency base to \(params.ropeFreqBase)");
os_log_info(os_log_inst, "changing RoPE frequency base to \(params.ropeFreqBase)");
}
if (params.ropeFreqScale != 0.0) {
os_log_info(OS_LOG_DEFAULT, "scaling RoPE frequency by \(params.ropeFreqScale)");
os_log_info(os_log_inst, "scaling RoPE frequency by \(params.ropeFreqScale)");
}
llama_backend_init();
llama_numa_init(ggml_numa_strategy(params.numaStrategy));
auto llama_init = common_init_from_params([params params]);
if (llama_init.context == nil) {
[NSException raise:@"ContextFailure"
format:@"could not load context"];
}
if (llama_init.model == nil) {
[NSException raise:@"ModelLoadFailure"
format:@"could not load model"];
}
auto tpp_batch = params.cpuParamsBatch.ggmlThreadpoolParams;
auto tpp = params.cpuParams.ggmlThreadpoolParams;
@@ -179,7 +190,7 @@ static BOOL file_is_empty(NSString *path) {
if (tpp != tpp_batch) {
threadpool_batch = [tpp_batch threadpool];
if (!threadpool_batch) {
[NSException raise:@"batch threadpool create failed"
[NSException raise:@"ThreadpoolFailure"
format:@"batch threadpool create failed"];
}
@@ -189,7 +200,7 @@ static BOOL file_is_empty(NSString *path) {
GGMLThreadpool *threadpool = [tpp threadpool];
if (!threadpool) {
[NSException raise:@"threadpool create failed"
[NSException raise:@"ThreadpoolFailure"
format:@"threadpool create failed"];
}
@@ -200,16 +211,16 @@ static BOOL file_is_empty(NSString *path) {
n_ctx = [self.ctx nCtx];
//
if (n_ctx > n_ctx_train) {
os_log_info(OS_LOG_DEFAULT, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
os_log_info(os_log_inst, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
}
// print chat template example in conversation mode
if (params.conversation) {
if (params.enableChatTemplate) {
os_log_info(OS_LOG_DEFAULT, "%s: chat template example:\n%s\n", __func__,
os_log_info(os_log_inst, "%s: chat template example:\n%s\n", __func__,
[[self.model formatExample:params.chatTemplate] cStringUsingEncoding:NSUTF8StringEncoding]);
} else {
os_log_info(OS_LOG_DEFAULT, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
os_log_info(os_log_inst, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
}
// print system information
@@ -222,11 +233,11 @@ static BOOL file_is_empty(NSString *path) {
NSFileManager *fileManager = [NSFileManager defaultManager];
if ([pathSession length] != 0) {
os_log_info(OS_LOG_DEFAULT, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
if (![fileManager fileExistsAtPath:pathSession]) {
os_log_info(OS_LOG_DEFAULT, "%s: session file does not exist, will create.\n", __func__);
os_log_info(os_log_inst, "%s: session file does not exist, will create.\n", __func__);
} else if (file_is_empty(pathSession)) {
os_log_info(OS_LOG_DEFAULT,"%s: The session file is empty. A new session will be initialized.\n", __func__);
os_log_info(os_log_inst,"%s: The session file is empty. A new session will be initialized.\n", __func__);
} else {
// The file exists and is not empty
session_tokens.resize(n_ctx);
@@ -235,7 +246,7 @@ static BOOL file_is_empty(NSString *path) {
[NSException raise:@"SessionLoadFailure" format:@"%s: failed to load session file '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]];
}
session_tokens.resize(n_token_count_out);
os_log_info(OS_LOG_DEFAULT,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
os_log_info(os_log_inst,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
}
}
@@ -244,7 +255,7 @@ static BOOL file_is_empty(NSString *path) {
GGML_ASSERT(![self.model addEOSToken]);
}
os_log_debug(OS_LOG_DEFAULT, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS);
os_log_debug(os_log_inst, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS);
{
@@ -252,22 +263,22 @@ static BOOL file_is_empty(NSString *path) {
? [self chat_add_and_format:chat_msgs role:"system" content:[params params].prompt] // format the system prompt in conversation mode
: params.prompt;
if (params.interactiveFirst || [params.prompt length] > 0 || session_tokens.empty()) {
os_log_debug(OS_LOG_DEFAULT, "tokenize the prompt\n");
os_log_debug(os_log_inst, "tokenize the prompt\n");
embd_inp = [self.ctx tokenize:prompt addSpecial:true parseSpecial:true];
} else {
os_log_debug(OS_LOG_DEFAULT,"use session tokens\n");
os_log_debug(os_log_inst,"use session tokens\n");
embd_inp = session_tokens;
}
os_log_debug(OS_LOG_DEFAULT,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_debug(OS_LOG_DEFAULT,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str());
os_log_debug(os_log_inst,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_debug(os_log_inst,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str());
}
// Should not run without any tokens
if (embd_inp.empty()) {
if (addBOS) {
embd_inp.push_back([self.model tokenBOS]);
// LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
os_log_info(os_log_inst, "embd_inp was considered empty and bos was added: %s\n", [_ctx convertTokensToString:embd_inp].c_str());
} else {
[NSException raise:@"InputEmptyError" format:@"input is empty"];
}
@@ -303,13 +314,13 @@ static BOOL file_is_empty(NSString *path) {
llama_kv_cache_seq_rm([self.ctx cContext], -1, n_matching_session_tokens, -1);
}
//
// os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
// os_log_debug(os_log_inst, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
// embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
//
// if we will use the cache for the full prompt without reaching the end of the cache, force
// reevaluation of the last token to recalculate the cached logits
if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
// os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
// os_log_debug(os_log_inst, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
session_tokens.resize(embd_inp.size() - 1);
}
@@ -331,22 +342,21 @@ static BOOL file_is_empty(NSString *path) {
}
if (params.verbosePrompt) {
// LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
os_log_info(os_log_inst,
"%s: prompt: '%s'\n", __func__, [params.prompt cStringUsingEncoding:NSUTF8StringEncoding]);
// LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", embd_inp[i],
os_log_info(os_log_inst, "%6d -> '%s'\n", embd_inp[i],
[[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
}
if (params.nKeep > addBOS) {
// LOG_INF("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.nKeep; i++) {
os_log_debug(OS_LOG_DEFAULT, "%s",
os_log_debug(os_log_inst, "%s",
[[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
}
// LOG("'\n");
}
// LOG_INF("\n");
}
//
// // ctrl+C handling
@@ -366,55 +376,55 @@ static BOOL file_is_empty(NSString *path) {
// }
//
if (params.interactive) {
os_log_info(OS_LOG_DEFAULT, "%s: interactive mode on.\n", __func__);
os_log_info(os_log_inst, "%s: interactive mode on.\n", __func__);
if ([params.antiPrompts count] > 0) {
for (NSString *antiprompt in params.antiPrompts) {
os_log_info(OS_LOG_DEFAULT, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
if (params.verbosePrompt) {
auto tmp = [_ctx tokenize:antiprompt
addSpecial:false
parseSpecial:true];
for (int i = 0; i < (int) tmp.size(); i++) {
os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
}
}
if (params.inputPrefixBOS) {
os_log_info(OS_LOG_DEFAULT, "Input prefix with BOS\n");
os_log_info(os_log_inst, "Input prefix with BOS\n");
}
if ([params.inputPrefix length] > 0) {
os_log_info(OS_LOG_DEFAULT, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
if (params.verbosePrompt) {
auto tmp = [_ctx tokenize:params.inputPrefix addSpecial:true parseSpecial:true];
for (int i = 0; i < (int) tmp.size(); i++) {
os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n",
os_log_info(os_log_inst, "%6d -> '%s'\n",
tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
}
if ([params.inputSuffix length] > 0) {
os_log_info(OS_LOG_DEFAULT, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
if (params.verbosePrompt) {
auto tmp = [_ctx tokenize:params.inputSuffix addSpecial:false parseSpecial:true];
for (int i = 0; i < (int) tmp.size(); i++) {
os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n",
os_log_info(os_log_inst, "%6d -> '%s'\n",
tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
}
}
smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]];
if (!smpl) {
_smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]];
if (!_smpl) {
[NSException raise:@"SamplingFailure" format:@"failed to initialize sampling subsystem"];
}
os_log_info(OS_LOG_DEFAULT, "sampler seed: %u\n", [smpl seed]);
os_log_info(os_log_inst, "sampler seed: %u\n", [_smpl seed]);
// LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
// LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
//
@@ -431,7 +441,7 @@ static BOOL file_is_empty(NSString *path) {
GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
os_log_info(OS_LOG_DEFAULT, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast<long>(ga_n), static_cast<long>(ga_w));
os_log_info(os_log_inst, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast<long>(ga_n), static_cast<long>(ga_w));
}
if (params.interactive) {
@@ -454,14 +464,6 @@ static BOOL file_is_empty(NSString *path) {
need_to_save_session = [pathSession length] > 0 && n_matching_session_tokens < embd_inp.size();
n_remain = params.nPredict;
// // the first thing we will do is to output the prompt, so set color accordingly
// console::set_display(console::prompt);
// display = params.display_prompt;
//
antiprompt_ids.reserve([params.antiPrompts count]);
for (NSString *antiprompt in params.antiPrompts) {
antiprompt_ids.emplace_back([self.ctx tokenize:antiprompt addSpecial:false parseSpecial:true]);
@@ -486,8 +488,13 @@ static BOOL file_is_empty(NSString *path) {
return self;
}
// MARK: LastOutput
- (NSString *)lastOutput {
return [_mutableLastOutput copy];
}
- (void)start:(BlockingLineQueue *)queue {
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
while ((n_remain != 0 && !is_antiprompt) || _params.interactive) {
// predict
if (!embd.empty()) {
// Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
@@ -500,42 +507,42 @@ static BOOL file_is_empty(NSString *path) {
embd.resize(max_embd_size);
// console::set_display(console::error);
os_log_error(OS_LOG_DEFAULT, "<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
os_log_error(os_log_inst, "<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
// console::set_display(console::reset);
}
if (params.grpAttnN == 1) {
if (_params.grpAttnN == 1) {
// infinite text generation via context shifting
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() >= [_ctx nCtx]) {
if (!params.ctxShift) {
os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and context shift is disabled => stopping\n", __func__);
if (!_params.ctxShift) {
os_log_debug(os_log_inst, "\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
} else {
if (params.nPredict == -2) {
os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.nPredict);
if (_params.nPredict == -2) {
os_log_debug(os_log_inst, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, _params.nPredict);
break;
}
const int n_left = n_past - params.nKeep;
const int n_left = n_past - _params.nKeep;
const int n_discard = n_left/2;
os_log_debug(OS_LOG_DEFAULT, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n",
n_past, n_left, static_cast<unsigned long>([_ctx nCtx]), params.nKeep, n_discard);
os_log_debug(os_log_inst, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n",
n_past, n_left, static_cast<unsigned long>([_ctx nCtx]), _params.nKeep, n_discard);
llama_kv_cache_seq_rm ([self.ctx cContext], 0, params.nKeep , params.nKeep + n_discard);
llama_kv_cache_seq_add([self.ctx cContext], 0, params.nKeep + n_discard, n_past, -n_discard);
llama_kv_cache_seq_rm ([self.ctx cContext], 0, _params.nKeep , _params.nKeep + n_discard);
llama_kv_cache_seq_add([self.ctx cContext], 0, _params.nKeep + n_discard, n_past, -n_discard);
n_past -= n_discard;
os_log_debug(OS_LOG_DEFAULT, "after swap: n_past = %d\n", n_past);
os_log_debug(os_log_inst, "after swap: n_past = %d\n", n_past);
os_log_debug(OS_LOG_DEFAULT, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str());
os_log_debug(os_log_inst, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str());
os_log_debug(OS_LOG_DEFAULT, "clear session path\n");
os_log_debug(os_log_inst, "clear session path\n");
[pathSession setString:@""];
}
}
@@ -546,10 +553,10 @@ static BOOL file_is_empty(NSString *path) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
const int dd = (ga_w/ga_n) - ib*bd - ga_w;
os_log_debug(OS_LOG_DEFAULT, "\n");
os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i), n_past, ib*bd, static_cast<long>(ga_i + ib*bd), n_past + ib*bd);
os_log_debug(OS_LOG_DEFAULT, "div: [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast<long>(ga_i + ib*bd), static_cast<long>(ga_i + ib*bd + ga_w), static_cast<long>(ga_n), static_cast<long>((ga_i + ib*bd)/ga_n), static_cast<long>((ga_i + ib*bd + ga_w)/ga_n));
os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast<long>(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd);
os_log_debug(os_log_inst, "\n");
os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i), n_past, ib*bd, static_cast<long>(ga_i + ib*bd), n_past + ib*bd);
os_log_debug(os_log_inst, "div: [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast<long>(ga_i + ib*bd), static_cast<long>(ga_i + ib*bd + ga_w), static_cast<long>(ga_n), static_cast<long>((ga_i + ib*bd)/ga_n), static_cast<long>((ga_i + ib*bd + ga_w)/ga_n));
os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast<long>(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd);
[self.ctx kvCacheSeqAdd:0 p0:ga_i p1:n_past delta:ib*bd];
[self.ctx kvCacheSeqDiv:0 p0:ga_i + ib*bd p1:ga_i + ib*bd + ga_w delta:ga_n];
@@ -559,7 +566,7 @@ static BOOL file_is_empty(NSString *path) {
ga_i += ga_w/ga_n;
os_log_debug(OS_LOG_DEFAULT, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast<long>(ga_i));
os_log_debug(os_log_inst, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast<long>(ga_i));
}
}
@@ -585,13 +592,13 @@ static BOOL file_is_empty(NSString *path) {
}
}
for (int i = 0; i < (int) embd.size(); i += params.nBatch) {
for (int i = 0; i < (int) embd.size(); i += _params.nBatch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.nBatch) {
n_eval = params.nBatch;
if (n_eval > _params.nBatch) {
n_eval = _params.nBatch;
}
os_log_debug(OS_LOG_DEFAULT, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str());
os_log_debug(os_log_inst, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str());
if ([self.ctx decode:[[LlamaBatch alloc] initWithBatch:llama_batch_get_one(&embd[i], n_eval)] ]) {
@@ -600,10 +607,10 @@ static BOOL file_is_empty(NSString *path) {
n_past += n_eval;
os_log_debug(OS_LOG_DEFAULT, "n_past = %d\n", n_past);
os_log_debug(os_log_inst, "n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.nPrint > 0 && n_past % params.nPrint == 0) {
os_log_debug(OS_LOG_DEFAULT, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast<unsigned long>([self.ctx nCtx]));
if (_params.nPrint > 0 && n_past % _params.nPrint == 0) {
os_log_debug(os_log_inst, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast<unsigned long>([self.ctx nCtx]));
}
}
@@ -617,19 +624,19 @@ static BOOL file_is_empty(NSString *path) {
if ((int) embd_inp.size() <= n_consumed && !isInteracting) {
// optionally save the session on first sample (for faster prompt loading next time)
if ([pathSession length] > 0 && need_to_save_session && !params.promptCacheRO) {
if ([pathSession length] > 0 && need_to_save_session && !_params.promptCacheRO) {
need_to_save_session = false;
[self.ctx saveStateFile:pathSession tokens:session_tokens.data() nTokenCount:session_tokens.size()];
// llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
os_log_debug(OS_LOG_DEFAULT, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_debug(os_log_inst, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
}
const llama_token idToken = [smpl sample:self.ctx index:-1];
const llama_token idToken = [_smpl sample:self.ctx index:-1];
[smpl accept:idToken acceptGrammar:true];
[_smpl accept:idToken acceptGrammar:true];
// os_log_debug(OS_LOG_DEFAULT, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
// os_log_debug(os_log_inst, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
embd.push_back(idToken);
@@ -639,19 +646,19 @@ static BOOL file_is_empty(NSString *path) {
// decrement remaining sampling budget
--n_remain;
os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain);
os_log_debug(os_log_inst, "n_remain: %d\n", n_remain);
} else {
// some user input remains from prompt or interaction, forward it to processing
os_log_debug(OS_LOG_DEFAULT, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
os_log_debug(os_log_inst, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
[smpl accept:embd_inp[n_consumed] acceptGrammar:false];
[_smpl accept:embd_inp[n_consumed] acceptGrammar:false];
++n_consumed;
if ((int) embd.size() >= params.nBatch) {
if ((int) embd.size() >= _params.nBatch) {
break;
}
}
@@ -662,10 +669,10 @@ static BOOL file_is_empty(NSString *path) {
// std::cout<< "DISPLAYING TEXT" << std::endl;
for (auto idToken : embd) {
NSString *token_str = [self.ctx tokenToPiece:idToken special:params.special];
NSString *token_str = [self.ctx tokenToPiece:idToken special:_params.special];
// Console/Stream Output
os_log_info(OS_LOG_DEFAULT, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]);
// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
@@ -678,6 +685,9 @@ static BOOL file_is_empty(NSString *path) {
output_tokens.push_back(idToken);
output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding];
last_output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding];
[self willChangeValueForKey:@"lastOutput"];
[_mutableLastOutput appendString:token_str];
[self didChangeValueForKey:@"lastOutput"];
}
}
@@ -698,23 +708,23 @@ static BOOL file_is_empty(NSString *path) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt in the last n_prev tokens
if ([params.antiPrompts count] > 0) {
if ([_params.antiPrompts count] > 0) {
const int n_prev = 32;
NSString *last_output = [smpl previousString:self.ctx n:n_prev];
NSString *last_output = [_smpl previousString:self.ctx n:n_prev];
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
// so we'll compensate for that by widening the search window a bit.
for (NSString *antiprompt in params.antiPrompts) {
size_t extra_padding = params.interactive ? 0 : 2;
for (NSString *antiprompt in _params.antiPrompts) {
size_t extra_padding = _params.interactive ? 0 : 2;
size_t search_start_pos = [last_output length] > static_cast<size_t>([antiprompt length] + extra_padding)
? [last_output length] - static_cast<size_t>([antiprompt length] + extra_padding)
: 0;
// TODO: Check if correct
if ([last_output rangeOfString:antiprompt options:0 range:NSMakeRange(search_start_pos, last_output.length - search_start_pos)].location != NSNotFound) {
if (params.interactive) {
if (_params.interactive) {
isInteracting = true;
}
is_antiprompt = true;
@@ -723,10 +733,10 @@ static BOOL file_is_empty(NSString *path) {
}
// check for reverse prompt using special tokens
llama_token last_token = [smpl last];
llama_token last_token = [_smpl last];
for (std::vector<llama_token> ids : antiprompt_ids) {
if (ids.size() == 1 && last_token == ids[0]) {
if (params.interactive) {
if (_params.interactive) {
isInteracting = true;
}
is_antiprompt = true;
@@ -735,25 +745,25 @@ static BOOL file_is_empty(NSString *path) {
}
if (is_antiprompt) {
os_log_debug(OS_LOG_DEFAULT, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_debug(os_log_inst, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]);
}
}
// deal with end of generation tokens in interactive mode
if ([self.model tokenIsEOG:[smpl last]]) {
os_log_debug(OS_LOG_DEFAULT, "found an EOG token\n");
if ([self.model tokenIsEOG:[_smpl last]]) {
os_log_debug(os_log_inst, "found an EOG token\n");
if (params.interactive) {
if ([[params antiPrompts] count] > 0) {
if (_params.interactive) {
if ([[_params antiPrompts] count] > 0) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = [self.ctx tokenize:params.antiPrompts[0] addSpecial:false parseSpecial:true];
const auto first_antiprompt = [self.ctx tokenize:_params.antiPrompts[0] addSpecial:false parseSpecial:true];
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
if (params.enableChatTemplate) {
if (_params.enableChatTemplate) {
[self chat_add_and_format:chat_msgs
role:"assistant"
content:assistant_ss.str()];
@@ -764,32 +774,32 @@ static BOOL file_is_empty(NSString *path) {
}
// if current token is not EOG, we add it to current assistant message
if (params.conversation) {
const auto idToken = [smpl last];
if (_params.conversation) {
const auto idToken = [_smpl last];
assistant_ss << [[self.ctx tokenToPiece:idToken special:false] cStringUsingEncoding:NSUTF8StringEncoding];
}
if (n_past > 0 && isInteracting) {
os_log_debug(OS_LOG_DEFAULT, "waiting for user input\n");
os_log_debug(os_log_inst, "waiting for user input\n");
if (params.conversation) {
if (_params.conversation) {
// osLog_("\n> ");
}
if (params.inputPrefixBOS) {
os_log_debug(OS_LOG_DEFAULT, "adding input prefix BOS token\n");
if (_params.inputPrefixBOS) {
os_log_debug(os_log_inst, "adding input prefix BOS token\n");
embd_inp.push_back([self.model tokenBOS]);
}
std::string buffer;
if ([params.inputPrefix length] > 0 && !params.conversation) {
os_log_debug(OS_LOG_DEFAULT, "appending input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(OS_LOG_DEFAULT, "%s", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
if ([_params.inputPrefix length] > 0 && !_params.conversation) {
os_log_debug(os_log_inst, "appending input prefix: '%s'\n", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "%s", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
}
// color user input only
// console::set_display(console::user_input);
display = params.displayPrompt;
display = _params.displayPrompt;
std::string line;
// bool another_line = true;
@@ -806,8 +816,12 @@ static BOOL file_is_empty(NSString *path) {
auto str = last_output_ss.str();
last_output_ss.str("");
[queue addOutputLine:[NSString stringWithCString:str.c_str() encoding:NSUTF8StringEncoding]];
[self willChangeValueForKey:@"lastOutput"];
_mutableLastOutput = [[NSMutableString alloc] init];
[self didChangeValueForKey:@"lastOutput"];
}
buffer = [[queue inputLine] cStringUsingEncoding:NSUTF8StringEncoding];
// do {
// another_line = console::readline(line, params.multiline_input);
@@ -822,34 +836,34 @@ static BOOL file_is_empty(NSString *path) {
// Entering an empty line lets the user pass control back
if (buffer.length() > 1) {
// append input suffix if any
if ([params.inputSuffix length] > 0 && !params.conversation) {
os_log_debug(OS_LOG_DEFAULT, "appending input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(OS_LOG_DEFAULT, "%s", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
if ([[self params].inputSuffix length] > 0 && !_params.conversation) {
os_log_debug(os_log_inst, "appending input suffix: '%s'\n", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
os_log_info(os_log_inst, "%s", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
}
os_log_debug(OS_LOG_DEFAULT, "buffer: '%s'\n", buffer.c_str());
os_log_debug(os_log_inst, "buffer: '%s'\n", buffer.c_str());
const size_t original_size = embd_inp.size();
if (params.escapeSequences) {
if (_params.escapeSequences) {
string_process_escapes(buffer);
}
bool format_chat = params.conversation && params.enableChatTemplate;
bool format_chat = _params.conversation && _params.enableChatTemplate;
std::string user_inp = format_chat
? [[self chat_add_and_format:chat_msgs role:"user" content:std::move(buffer)] cStringUsingEncoding:NSUTF8StringEncoding]
: std::move(buffer);
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = [self.ctx tokenize:params.inputPrefix addSpecial:false parseSpecial:true];
const auto line_pfx = [self.ctx tokenize:_params.inputPrefix addSpecial:false parseSpecial:true];
const auto line_inp = [self.ctx tokenize:[NSString stringWithCString:user_inp.c_str()
encoding:NSUTF8StringEncoding]
addSpecial:false
parseSpecial:format_chat];
const auto line_sfx = [self.ctx tokenize:params.inputSuffix
const auto line_sfx = [self.ctx tokenize:_params.inputSuffix
addSpecial:false
parseSpecial:true];
os_log_debug(OS_LOG_DEFAULT, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str());
os_log_debug(os_log_inst, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str());
// if user stop generation mid-way, we must add EOT to finish model's last response
if (need_insert_eot && format_chat) {
@@ -872,9 +886,9 @@ static BOOL file_is_empty(NSString *path) {
assistant_ss.str("");
n_remain -= line_inp.size();
os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain);
os_log_debug(os_log_inst, "n_remain: %d\n", n_remain);
} else {
os_log_debug(OS_LOG_DEFAULT, "empty line, passing control back\n");
os_log_debug(os_log_inst, "empty line, passing control back\n");
}
input_echo = false; // do not echo this again
@@ -882,22 +896,22 @@ static BOOL file_is_empty(NSString *path) {
if (n_past > 0) {
if (isInteracting) {
[smpl reset];
[_smpl reset];
}
isInteracting = false;
}
}
// end of generation
if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(params.interactive)) {
os_log_info(OS_LOG_DEFAULT, " [end of text]\n");
if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(_params.interactive)) {
os_log_info(os_log_inst, " [end of text]\n");
break;
}
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
// We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
if (params.interactive && n_remain <= 0 && params.nPredict >= 0) {
n_remain = params.nPredict;
if (_params.interactive && n_remain <= 0 && _params.nPredict >= 0) {
n_remain = _params.nPredict;
isInteracting = true;
}
}
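
start: runs the whole generation loop against a BlockingLineQueue, so a caller exchanges lines with the model from another thread. A minimal driver sketch (this is essentially what LlamaChatSession does in Swift below; __LlamaSession is the NS_REFINED_FOR_SWIFT name of the class):

let queue = BlockingLineQueue()
let session = __LlamaSession(params: params)
Task.detached { [session, queue] in
    session.start(queue)            // loop blocks on inputLine
}
queue.addInputLine("Hello!")
print(queue.outputLine())           // blocks until a full response is queued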

objc/include/CPUParams.h (new file, +30)

@@ -0,0 +1,30 @@
#ifndef CPUParams_h
#define CPUParams_h
typedef NS_ENUM(NSUInteger, GGMLSchedPriority);
@class GGMLThreadpoolParams;
@interface CPUParams : NSObject
/// Number of threads to use
@property (nonatomic, assign) NSInteger nThreads;
/// Default: any CPU
@property (nonatomic, assign) BOOL maskValid;
/// Scheduling priority
@property (nonatomic, assign) GGMLSchedPriority priority;
/// Use strict CPU placement
@property (nonatomic, assign) BOOL strictCPU;
/// Polling (busywait) level (0 - no polling, 100 - mostly polling)
@property (nonatomic, assign) NSUInteger poll;
// Custom methods to access or manipulate the cpumask array
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
- (GGMLThreadpoolParams *)ggmlThreadpoolParams;
@end
#endif /* CPUParams_h */


@@ -0,0 +1,15 @@
#ifndef CPUParams_Private_hpp
#define CPUParams_Private_hpp
#import "CPUParams.h"
#import "../../common/common.h"
@interface CPUParams() {
cpu_params *params;
}
- (instancetype)initWithParams:(cpu_params&)params;
@end
#endif /* CPUParams_Private_hpp */


@@ -0,0 +1,29 @@
#ifndef GGMLThreadpool_h
#define GGMLThreadpool_h
typedef NS_ENUM(NSUInteger, GGMLSchedPriority) {
GGMLSchedPriorityNormal = 0, // Normal priority
GGMLSchedPriorityMedium = 1, // Medium priority
GGMLSchedPriorityHigh = 2, // High priority
GGMLSchedPriorityRealtime = 3 // Realtime priority
};
@interface GGMLThreadpool : NSObject
@end
@interface GGMLThreadpoolParams : NSObject
@property (nonatomic, assign) int nThreads;
@property (nonatomic, assign) GGMLSchedPriority priority;
@property (nonatomic, assign) uint32_t poll;
@property (nonatomic, assign) BOOL strictCPU;
@property (nonatomic, assign) BOOL paused;
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
- (GGMLThreadpool *)threadpool;
@end
#endif /* GGMLThreadpool_h */


@@ -0,0 +1,22 @@
#ifndef GGMLThreadpool_Private_hpp
#define GGMLThreadpool_Private_hpp
#import "GGMLThreadpool.h"
#import "../../common/common.h"
@interface GGMLThreadpool() {
ggml_threadpool *threadpool;
}
- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool;
- (ggml_threadpool *)threadpool;
@end
@interface GGMLThreadpoolParams()
- (instancetype)initWithParams:(ggml_threadpool_params&&)params;
@end
#endif /* GGMLThreadpool_Private_hpp */


@@ -4,48 +4,7 @@
@class LlamaModelParams;
@class LlamaContextParams;
@class GGMLThreadpool;
// Define the ggml_sched_priority enum
typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
GGMLSchedPriorityNormal = 0, // Normal priority
GGMLSchedPriorityMedium = 1, // Medium priority
GGMLSchedPriorityHigh = 2, // High priority
GGMLSchedPriorityRealtime = 3 // Realtime priority
};
@interface GGMLThreadpoolParams : NSObject
@property (nonatomic, assign) int nThreads;
@property (nonatomic, assign) GGMLSchedPriority priority;
@property (nonatomic, assign) uint32_t poll;
@property (nonatomic, assign) BOOL strictCPU;
@property (nonatomic, assign) BOOL paused;
// Custom access methods for the cpumask array
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
- (GGMLThreadpool *)threadpool;
@end
@interface GGMLThreadpool : NSObject
@end
@interface CPUParams : NSObject
// Properties
@property (nonatomic, assign) int nThreads;
@property (nonatomic, assign) BOOL maskValid;
@property (nonatomic, assign) GGMLSchedPriority priority;
@property (nonatomic, assign) BOOL strictCPU;
@property (nonatomic, assign) uint32_t poll;
// Custom methods to access or manipulate the cpumask array
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
- (GGMLThreadpoolParams *)ggmlThreadpoolParams;
@end
@class CPUParams;
@interface GPTSamplerParams : NSObject
@@ -72,13 +31,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
@property (nonatomic, assign) BOOL penalizeNl;
@property (nonatomic, assign) BOOL ignoreEos;
@property (nonatomic, assign) BOOL noPerf;
@property (nonatomic, strong) NSArray<NSNumber *> *samplers;
@property (nonatomic, copy) NSString *grammar;
@property (nonatomic, strong) NSArray<NSNumber *> *logitBias;
// Arrays and Strings
@property (nonatomic, strong) NSArray<NSNumber *> *samplers; // Samplers mapped to NSArray of NSNumber (for enums)
@property (nonatomic, copy) NSString *grammar; // Grammar as NSString
@property (nonatomic, strong) NSArray<NSNumber *> *logitBias; // Logit biases mapped to NSArray of NSNumber
// Method to print the parameters into a string
- (NSString *)print;
@end
@@ -98,7 +54,7 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
@property (nonatomic, assign) int32_t nGpuLayers;
@property (nonatomic, assign) int32_t nGpuLayersDraft;
@property (nonatomic, assign) int32_t mainGpu;
@property (nonatomic, strong) NSMutableArray<NSNumber *> *tensorSplit; // Fixed-size array, stays the same
@property (nonatomic, strong) NSArray<NSNumber *> *tensorSplit;
@property (nonatomic, assign) int32_t grpAttnN;
@property (nonatomic, assign) int32_t grpAttnW;
@property (nonatomic, assign) int32_t nPrint;
@@ -111,13 +67,11 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
@property (nonatomic, assign) int32_t yarnOrigCtx;
@property (nonatomic, assign) float defragThold;
// You need to replace your C++ struct "cpu_params" with an Objective-C class or struct accordingly
@property (nonatomic, strong) CPUParams *cpuParams;
@property (nonatomic, strong) CPUParams *cpuParamsBatch;
@property (nonatomic, strong) CPUParams *draftCpuParams;
@property (nonatomic, strong) CPUParams *draftCpuParamsBatch;
// Callbacks (assuming they are blocks in Objective-C)
@property (nonatomic, copy) void (^cbEval)(void *);
@property (nonatomic, assign) void *cbEvalUserData;
@@ -149,12 +103,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
@property (nonatomic, copy) NSString *logitsFile;
@property (nonatomic, copy) NSString *rpcServers;
// Arrays in Objective-C are represented with `NSArray`
@property (nonatomic, strong) NSArray<NSString *> *inputFiles;
@property (nonatomic, strong) NSArray<NSString *> *antiPrompts;
@property (nonatomic, strong) NSArray *kvOverrides;
// Boolean values (in Objective-C, use `BOOL`)
@property (nonatomic, assign) BOOL loraInitWithoutApply;
@property (nonatomic, strong) NSArray *loraAdapters;
@property (nonatomic, strong) NSArray *controlVectors;
@@ -257,6 +209,8 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
@property (nonatomic, assign) BOOL ctxShift; // context shift on infinite text generation
@property (nonatomic, assign) BOOL displayPrompt; // print prompt before generation
@property (nonatomic, assign) BOOL logging; // print logging
@end
#endif /* GPTParams_h */
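
A minimal configuration sketch using the properties above (the model path is a placeholder; modelPath must be set or LlamaSession's initializer raises ModelFailure):

let params = GPTParams()
params.modelPath = "/path/to/model.gguf"
params.nCtx = 4096          // values below 8 are clamped to the minimum
params.interactive = true
params.logging = false      // routes os_log output to OS_LOG_DISABLED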


@@ -5,11 +5,6 @@
#import "ggml.h"
#import "../../common/common.h"
@interface GGMLThreadpool()
- (ggml_threadpool *)threadpool;
@end
@interface GPTParams()


@@ -38,14 +38,14 @@ typedef int32_t LlamaToken;
index:(NSInteger) index
grammarFirst:(BOOL)grammarFirst;
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
/// If accept_grammar is true, the token is accepted both by the sampling chain and the grammar
- (void)accept:(LlamaToken)token
acceptGrammar:(BOOL)acceptGrammar;
// get a string representation of the last accepted tokens
/// Get a string representation of the last accepted tokens
- (NSString *)previousString:(LlamaContext *)context n:(NSInteger)n;
// get the last accepted token
/// Get the last accepted token
- (LlamaToken)last;
- (void)reset;


@@ -5,15 +5,15 @@ typedef NSInteger LlamaSequenceId;
typedef NSInteger LlamaPosition;
typedef int32_t LlamaToken;
// Input data for llama_decode
// A llama_batch object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
//
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
/// Input data for llama_decode
/// A llama_batch object can contain input about one or many sequences
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
///
/// - token : the token ids of the input (used when embd is NULL)
/// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// - pos : the positions of the respective token in the sequence
/// - seq_id : the sequence to which the respective token belongs
/// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
@interface LlamaBatch : NSObject
@property (nonatomic, assign) NSInteger nTokens;
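
A hedged Swift sketch of wrapping a single run of tokens, mirroring the call site in the session loop above (assumes initWithBatch: bridges as init(batch:); the token ids are illustrative):

var tokens: [LlamaToken] = [15339, 1917]
let batch = tokens.withUnsafeMutableBufferPointer { buf in
    LlamaBatch(batch: llama_batch_get_one(buf.baseAddress, Int32(buf.count)))
}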


@@ -2,6 +2,8 @@
#define LlamaObjC_h
#include <Foundation/Foundation.h>
#include <CPUParams.h>
#include <GGMLThreadpool.h>
#include <GPTParams.h>
#include <GPTSampler.h>
#include <llama.h>


@@ -14,10 +14,11 @@
@end
@interface LlamaSession : NSObject
NS_REFINED_FOR_SWIFT @interface LlamaSession : NSObject
@property (nonatomic, strong) LlamaModel *model;
@property (nonatomic, strong) LlamaContext *ctx;
@property (nonatomic, strong, readonly) NSString *lastOutput;
- (instancetype)initWithParams:(GPTParams *)params;
- (void)start:(BlockingLineQueue *)queue;


@@ -3,8 +3,14 @@
#import "LlamaSession.h"
@class GPTSampler;
@interface LlamaSession()
@property (atomic, strong) NSMutableString *mutableLastOutput;
@property (nonatomic, strong) GPTParams *params;
@property (nonatomic, strong) GPTSampler *smpl;
@end
#endif /* LlamaSession_Private_hpp */


@@ -24,8 +24,6 @@ public class SchemaConverter {
}
private func formatLiteral(_ literal: Any) -> String {
// let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)", {
// let regex = Regex("[\r\n\"]")
let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)") {
$0.replacingOccurrences(of: $1.key, with: $1.value)
}


@@ -57,6 +57,28 @@ public struct _JSONFunctionSchema: Codable {
self.enum = nil
}
public init(type: Int.Type, description: String?) {
self.type = "integer"
self.description = description
self.items = nil
self.enum = nil
}
public init(type: Double.Type, description: String?) {
self.type = "number"
self.description = description
self.items = nil
self.enum = nil
}
public init(type: Bool.Type, description: String?) {
self.type = "boolean"
self.description = description
self.items = nil
self.enum = nil
}
public init<T: CaseIterable>(type: T.Type, description: String?) where T: RawRepresentable,
T: StringProtocol {
self.type = "string"
@@ -146,6 +168,14 @@ extension Double : JSONSchemaConvertible {
]
}
}
extension Bool : JSONSchemaConvertible {
public static var type: String { "boolean" }
public static var jsonSchema: [String: Any] {
[
"type": "boolean"
]
}
}
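
With this conformance, boolean fields participate in schema generation just like Int and Double above, so a grammar can be derived for them. For example, feeding the fragment through SchemaConverter (mirroring what LlamaSession<T> does below):

let converter = SchemaConverter(propOrder: [])
_ = converter.visit(schema: Bool.jsonSchema, name: nil)  // ["type": "boolean"]
let grammar = converter.formatGrammar()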
extension Date : JSONSchemaConvertible {
public static var type: String { "string" }


@@ -7,11 +7,63 @@ public protocol DynamicCallable: Sendable {
func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String
}
public enum AnyDecodable: Decodable {
case string(String)
case int(Int)
case double(Double)
case bool(Bool)
case null
// Add other cases as needed
// Initializers for each type
init(_ value: String) {
self = .string(value)
}
init(_ value: Int) {
self = .int(value)
}
init(_ value: Double) {
self = .double(value)
}
init(_ value: Bool) {
self = .bool(value)
}
init() {
self = .null
}
// Decodable conformance
public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
if container.decodeNil() {
self = .null
} else if let intValue = try? container.decode(Int.self) {
self = .int(intValue)
} else if let doubleValue = try? container.decode(Double.self) {
self = .double(doubleValue)
} else if let boolValue = try? container.decode(Bool.self) {
self = .bool(boolValue)
} else if let stringValue = try? container.decode(String.self) {
self = .string(stringValue)
} else {
let context = DecodingError.Context(
codingPath: decoder.codingPath,
debugDescription: "Cannot decode AnyDecodable"
)
throw DecodingError.typeMismatch(AnyDecodable.self, context)
}
}
}
struct ToolCall: Decodable {
let id: Int
let name: String
let arguments: [String: String]
let arguments: [String: AnyDecodable]
}
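
Usage sketch: tool arguments with mixed JSON types now decode where the old [String: String] would fail. The payload here is illustrative:

let payload = #"{"id": 1, "name": "setThermostat", "arguments": {"room": "kitchen", "target": 21.5, "hold": true}}"#
let call = try JSONDecoder().decode(ToolCall.self, from: Data(payload.utf8))
// call.arguments["target"] is .double(21.5); call.arguments["hold"] is .bool(true)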
struct ToolResponse<T: Encodable>: Encodable {
@@ -23,10 +75,13 @@ struct ToolResponse<T: Encodable>: Encodable {
/// Standard chat session for a given LLM.
public actor LlamaChatSession {
private let queue = BlockingLineQueue()
private let session: LlamaObjC.LlamaSession
private let session: __LlamaSession
/// Initialize the session
/// - parameter params: common parameters to initialize the session
/// - parameter flush: whether or not to flush the initial prompt, reading initial output
public init(params: GPTParams, flush: Bool = true) async throws {
session = LlamaObjC.LlamaSession(params: params)
self.session = __LlamaSession(params: params)
Task.detached { [session, queue] in
session.start(queue)
}
@@ -36,7 +91,36 @@ public actor LlamaChatSession {
_ = queue.outputLine()
}
public func chat(message: String) async -> String {
/// Create a new inference stream for a given message
/// - parameter message: The message to receive an inference for.
/// - returns: A stream of output from the LLM.
public func inferenceStream(message: String) async -> AsyncStream<String> {
queue.addInputLine(message)
var observationToken: NSKeyValueObservation?
return AsyncStream { stream in
observationToken = self.session.observe(\.lastOutput, options: [.new, .old]) { session, change in
guard let newValue = change.newValue,
let oldValue = change.oldValue else {
return stream.finish()
}
var delta = ""
for change in newValue!.difference(from: oldValue!) {
switch change {
case .remove(_, _, _):
return stream.finish()
case .insert(_, let element, _):
delta.append(element)
}
}
stream.yield(delta)
}
stream.onTermination = { [observationToken] _ in
observationToken?.invalidate()
}
}
}
public func infer(message: String) async -> String {
queue.addInputLine(message)
return queue.outputLine()
}
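
Consuming the new stream (each yielded delta is the newly appended suffix of lastOutput):

let session = try await LlamaChatSession(params: params)
for await delta in await session.inferenceStream(message: "Tell me a story") {
    print(delta, terminator: "")
}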
@@ -46,15 +130,15 @@
public actor LlamaSession<T: JSONSchemaConvertible> {
private let session: LlamaChatSession
public init(params: GPTParams) async throws {
public init(params: GPTParams, flush: Bool = true) async throws {
let converter = SchemaConverter(propOrder: [])
_ = converter.visit(schema: T.jsonSchema, name: nil)
params.samplerParams.grammar = converter.formatGrammar()
session = try await LlamaChatSession(params: params)
session = try await LlamaChatSession(params: params, flush: flush)
}
public func chat(message: String) async throws -> T {
let output = await session.chat(message: message).data(using: .utf8)!
let output = await session.infer(message: message).data(using: .utf8)!
return try JSONDecoder().decode(T.self, from: output)
}
}
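
A hedged sketch of the typed session; with the Bool conformance added above, the generated grammar constrains the model's reply to parse as a JSON boolean:

let session = try await LlamaSession<Bool>(params: params)
let isPositive: Bool = try await session.chat(message: "Is this review positive? ...")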
@@ -117,11 +201,14 @@ public actor LlamaToolSession {
self.tools["getIpAddress"] = (GetIpAddress(), ipFnSchema)
let encoded = try JSONEncoder().encode(self.tools.values.map(\.1))
let prompt = """
\(params.prompt ?? "")
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Feel free to chain tool calls, e.g., if you need the user's location to find points of interest near them, fetch the user's location first.
The first call you will be asked to warm up is to get the user's IP address. Here are the available tools:
<tools> \(String(data: encoded, encoding: .utf8)!) </tools><|eot_id|>
"""
@@ -131,7 +218,7 @@ public actor LlamaToolSession {
params.inputPrefix = "<|start_header_id|>user<|end_header_id|>";
params.inputSuffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>";
session = try await LlamaChatSession(params: params, flush: false)
let fn = await session.chat(message: "What is my IP address?")
let fn = await session.infer(message: "What is my IP address?")
let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!)
guard let tool = self.tools[toolCall.name] else {
fatalError()
@@ -139,7 +226,7 @@ public actor LlamaToolSession {
let resp = try await tool.0.dynamicallyCall(withKeywordArguments: toolCall.arguments)
print(resp)
let output = await session.chat(message: """
let output = await session.infer(message: """
<tool_response>
{"id": \(toolCall.id), result: \(resp)}
</tool_response>
@@ -147,39 +234,58 @@ public actor LlamaToolSession {
print(output)
}
public func chat(message: String) async throws -> String {
var nxt = await session.chat(message: message)
let fn = nxt
// try to see if the output is a function call
private func callTool(_ call: String) async -> String? {
var nxt: String?
do {
let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!)
let toolCall = try JSONDecoder().decode(ToolCall.self, from: call.data(using: .utf8)!)
guard let tool = tools[toolCall.name] else {
fatalError()
}
// TODO: tool call decode is allowed to fail but the code below is not
let callable = tool.0
let resp = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments)
print("tool response: \(resp)")
nxt = await session.chat(message: """
<tool_response>
{"id": \(toolCall.id), result: \(resp)}
</tool_response>
""")
print(nxt)
} catch {
print(error)
}
do {
let response = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments)
print("tool response: \(response)")
nxt = await session.infer(message: """
<tool_response>
{"id": \(toolCall.id), result: \(response)}
</tool_response>
""")
// TODO: If this decodes as another tool call, pass the decoded value
// TODO: into the recursive call so that we do not decode twice
if let _ = try? JSONDecoder().decode(ToolCall.self, from: nxt!.data(using: .utf8)!) {
return await callTool(nxt!)
}
} catch {
nxt = await session.infer(message: """
<tool_response>
{"id": \(toolCall.id), result: "The tool call has unfortunately failed."}
</tool_response>
""")
}
print(nxt ?? "nil")
} catch {}
return nxt
}
public func infer(message: String) async throws -> String {
let output = await session.infer(message: message)
guard let output = await callTool(output) else {
return output
}
return output
}
}
public protocol LlamaActor: Actor {
static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] { get }
var session: LlamaToolSession { get }
static func tools(_ self: Self) -> [String: (DynamicCallable, _JSONFunctionSchema)]
var session: LlamaToolSession! { get }
}
public extension LlamaActor {
func chat(_ message: String) async throws -> String {
try await session.chat(message: message)
try await session.infer(message: message)
}
}

View file

@@ -16,12 +16,14 @@ struct ToolMacro: BodyMacro {
struct LlamaActorMacro: ExtensionMacro, MemberMacro {
static func expansion(of node: AttributeSyntax, providingMembersOf declaration: some DeclGroupSyntax, conformingTo protocols: [TypeSyntax], in context: some MacroExpansionContext) throws -> [DeclSyntax] {
[
return [
"""
let session: LlamaToolSession
var session: LlamaToolSession!
public init(params: GPTParams) async throws {
self.session = try await LlamaToolSession(params: params, tools: Self.tools)
\(raw: declaration.inheritanceClause != nil ? "self.init()" : "")
let tools = Self.tools(self)
self.session = try await LlamaToolSession(params: params, tools: tools)
}
"""
]
@@ -41,6 +43,7 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
callableString: String,
callableName: String)
] = []
let typeName = type.as(IdentifierTypeSyntax.self)!.name.text
for member in declaration.memberBlock.members {
let comments = member.leadingTrivia.filter { $0.isComment }
guard let member = member.decl.as(FunctionDeclSyntax.self) else {
@@ -72,6 +75,11 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
let callableName = context.makeUniqueName(name.text)
let callableString = """
@dynamicCallable struct \(callableName.text): DynamicCallable {
private weak var llamaActor: \(typeName)?
init(_ llamaActor: \(typeName)) {
self.llamaActor = llamaActor
}
@discardableResult
func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String {
\(parameters.map {
@@ -79,11 +87,15 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
}.joined(separator: "\n"))
for (key, value) in args {
\(parameters.map {
"if key == \"\($0.name)\" { \($0.name) = value as! \($0.type) }"
"""
if key == "\($0.name)", let value = value as? AnyDecodable, case let .\($0.type.lowercased())(v) = value {
\($0.name) = v
}
"""
}.joined(separator: "\n"))
}
let returnValue = try await \(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ",")))
let returnValue = try await self.llamaActor!.\(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ",")))
let jsonValue = try JSONEncoder().encode(returnValue)
return String(data: jsonValue, encoding: .utf8)!
}
@@ -105,10 +117,10 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
$0.callableString
}.joined(separator: "\n"))
static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] {
static func tools(_ self: \(raw: typeName)) -> [String: (DynamicCallable, _JSONFunctionSchema)] {
[\(raw: tools.map { tool in
"""
"\(tool.name)": (\(tool.callableName)(), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in
"\(tool.name)": (\(tool.callableName)(self), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in
"""
"\(parameter.name)": _JSONFunctionSchema.Property(type: \(parameter.type).self, description: "\(parameter.description)"),
"""

View file

@@ -32,7 +32,7 @@ func downloadFile() async throws -> String {
@llamaActor actor MyLlama {
/// Get the current date.
@Tool public static func getCurrentDate() -> String {
@Tool public func getCurrentDate() -> String {
Date.now.formatted(date: .long, time: .complete)
}
}
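Given the macro expansion, a minimal usage sketch (params assumed configured as in the surrounding tests):
let llama = try await MyLlama(params: params)
// chat(_:) routes through LlamaToolSession, which may answer by
// emitting a <tool_call> that invokes getCurrentDate() above.
let answer = try await llama.chat("What is today's date?")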

View file

@@ -0,0 +1,370 @@
import Foundation
import XCTest
import LlamaKit // Replace with your module name
final class GPTParamsTests: XCTestCase {
func testPropertyAssignmentsAndCopy() throws {
// Create an instance of GPTParams
let originalParams = GPTParams()
// Assign values to all properties
originalParams.nPredict = 10
originalParams.nCtx = 20
originalParams.nBatch = 30
originalParams.nUBatch = 40
originalParams.nKeep = 50
originalParams.nDraft = 60
originalParams.nChunks = 70
originalParams.nParallel = 80
originalParams.nSequences = 90
originalParams.pSplit = 0.5
originalParams.nGpuLayers = 100
originalParams.nGpuLayersDraft = 110
originalParams.mainGpu = 120
originalParams.tensorSplit = [0.1, 0.2, 0.3]
originalParams.grpAttnN = 130
originalParams.grpAttnW = 140
originalParams.nPrint = 150
originalParams.ropeFreqBase = 0.6
originalParams.ropeFreqScale = 0.7
originalParams.yarnExtFactor = 0.8
originalParams.yarnAttnFactor = 0.9
originalParams.yarnBetaFast = 1.0
originalParams.yarnBetaSlow = 1.1
originalParams.yarnOrigCtx = 160
originalParams.defragThold = 1.2
// Initialize CPUParams instances if needed
originalParams.cpuParams = CPUParams()
originalParams.cpuParamsBatch = CPUParams()
originalParams.draftCpuParams = CPUParams()
originalParams.draftCpuParamsBatch = CPUParams()
// Assign blocks and user data
originalParams.cbEval = { userData in
// Callback implementation
}
originalParams.cbEvalUserData = nil
// Assign enum values (assuming NSInteger maps to Int)
originalParams.numaStrategy = 1
originalParams.splitMode = 2
originalParams.ropeScalingType = 3
originalParams.poolingType = 4
originalParams.attentionType = 5
// Assign sampler parameters
originalParams.samplerParams = GPTSamplerParams()
// Assign string properties
originalParams.modelPath = "path/to/model"
originalParams.modelDraft = "model_draft"
originalParams.modelAlias = "alias"
originalParams.modelURL = "http://model.url"
originalParams.hfToken = "token"
originalParams.hfRepo = "repo"
originalParams.hfFile = "file"
originalParams.prompt = "prompt"
originalParams.promptFile = "prompt.txt"
originalParams.pathPromptCache = "cache/path"
originalParams.inputPrefix = "prefix"
originalParams.inputSuffix = "suffix"
originalParams.logdir = "log/dir"
originalParams.lookupCacheStatic = "static/cache"
originalParams.lookupCacheDynamic = "dynamic/cache"
originalParams.logitsFile = "logits.txt"
originalParams.rpcServers = "servers"
// Assign array properties
originalParams.inputFiles = ["input1.txt", "input2.txt"]
originalParams.antiPrompts = ["anti1", "anti2"]
originalParams.kvOverrides = ["override1", "override2"]
originalParams.loraAdapters = ["adapter1", "adapter2"]
originalParams.controlVectors = ["control1", "control2"]
// Assign boolean and control properties
originalParams.loraInitWithoutApply = true
originalParams.verbosity = 1
originalParams.controlVectorLayerStart = 2
originalParams.controlVectorLayerEnd = 3
originalParams.pplStride = 4
originalParams.pplOutputType = 5
originalParams.hellaswag = true
originalParams.hellaswagTasks = 10
originalParams.winogrande = false
originalParams.winograndeTasks = 20
originalParams.multipleChoice = true
originalParams.multipleChoiceTasks = 30
originalParams.klDivergence = false
originalParams.usage = true
originalParams.useColor = false
originalParams.special = true
originalParams.interactive = false
originalParams.interactiveFirst = true
originalParams.conversation = false
originalParams.promptCacheAll = true
originalParams.promptCacheRO = false
originalParams.escapeSequences = true
originalParams.multilineInput = false
originalParams.simpleIO = true
originalParams.continuousBatching = false
originalParams.flashAttention = true
originalParams.noPerformanceMetrics = false
originalParams.contextShift = true
// Server and I/O settings
originalParams.port = 8080
originalParams.timeoutRead = 60
originalParams.timeoutWrite = 30
originalParams.httpThreads = 4
originalParams.hostname = "localhost"
originalParams.publicPath = "/public"
originalParams.chatTemplate = "template"
originalParams.systemPrompt = "system prompt"
originalParams.enableChatTemplate = true
originalParams.apiKeys = ["key1", "key2"]
originalParams.sslFileKey = "key.pem"
originalParams.sslFileCert = "cert.pem"
originalParams.endpointSlots = true
originalParams.endpointMetrics = false
originalParams.logJSON = true
originalParams.slotSavePath = "/slots"
originalParams.slotPromptSimilarity = 0.75
// Batched-bench params
originalParams.isPPShared = true
originalParams.nPP = [1, 2]
originalParams.nTG = [3, 4]
originalParams.nPL = [5, 6]
// Retrieval params
originalParams.contextFiles = ["context1.txt", "context2.txt"]
originalParams.chunkSize = 1024
originalParams.chunkSeparator = "\n"
// Passkey params
originalParams.nJunk = 7
originalParams.iPos = 8
// Imatrix params
originalParams.outFile = "output.txt"
originalParams.nOutFreq = 100
originalParams.nSaveFreq = 200
originalParams.iChunk = 9
originalParams.processOutput = true
originalParams.computePPL = false
// Cvector-generator params
originalParams.nPCABatch = 10
originalParams.nPCAIterations = 11
originalParams.cvectorDimreMethod = 12
originalParams.cvectorOutfile = "cvector.out"
originalParams.cvectorPositiveFile = "positive.txt"
originalParams.cvectorNegativeFile = "negative.txt"
// Additional properties
originalParams.spmInfill = true
originalParams.loraOutfile = "lora.out"
originalParams.embedding = false
originalParams.verbosePrompt = true
originalParams.batchedBenchOutputJSONL = false
originalParams.inputPrefixBOS = true
originalParams.ctxShift = false
originalParams.displayPrompt = true
originalParams.logging = false
// Verify that properties are assigned correctly
XCTAssertEqual(originalParams.nPredict, 10)
XCTAssertEqual(originalParams.nCtx, 20)
XCTAssertEqual(originalParams.nBatch, 30)
XCTAssertEqual(originalParams.nUBatch, 40)
XCTAssertEqual(originalParams.nKeep, 50)
XCTAssertEqual(originalParams.nDraft, 60)
XCTAssertEqual(originalParams.nChunks, 70)
XCTAssertEqual(originalParams.nParallel, 80)
XCTAssertEqual(originalParams.nSequences, 90)
XCTAssertEqual(originalParams.pSplit, 0.5)
XCTAssertEqual(originalParams.nGpuLayers, 100)
XCTAssertEqual(originalParams.nGpuLayersDraft, 110)
XCTAssertEqual(originalParams.mainGpu, 120)
XCTAssertEqual(originalParams.tensorSplit[0..<3].map(\.floatValue),
[0.1, 0.2, 0.3])
XCTAssertEqual(originalParams.grpAttnN, 130)
XCTAssertEqual(originalParams.grpAttnW, 140)
XCTAssertEqual(originalParams.nPrint, 150)
XCTAssertEqual(originalParams.ropeFreqBase, 0.6)
XCTAssertEqual(originalParams.ropeFreqScale, 0.7)
XCTAssertEqual(originalParams.yarnExtFactor, 0.8)
XCTAssertEqual(originalParams.yarnAttnFactor, 0.9)
XCTAssertEqual(originalParams.yarnBetaFast, 1.0)
XCTAssertEqual(originalParams.yarnBetaSlow, 1.1)
XCTAssertEqual(originalParams.yarnOrigCtx, 160)
XCTAssertEqual(originalParams.defragThold, 1.2)
// Verify enums
XCTAssertEqual(originalParams.numaStrategy, 1)
XCTAssertEqual(originalParams.splitMode, 2)
XCTAssertEqual(originalParams.ropeScalingType, 3)
XCTAssertEqual(originalParams.poolingType, 4)
XCTAssertEqual(originalParams.attentionType, 5)
// Verify string properties
XCTAssertEqual(originalParams.modelPath, "path/to/model")
XCTAssertEqual(originalParams.modelDraft, "model_draft")
XCTAssertEqual(originalParams.modelAlias, "alias")
XCTAssertEqual(originalParams.modelURL, "http://model.url")
XCTAssertEqual(originalParams.hfToken, "token")
XCTAssertEqual(originalParams.hfRepo, "repo")
XCTAssertEqual(originalParams.hfFile, "file")
XCTAssertEqual(originalParams.prompt, "prompt")
XCTAssertEqual(originalParams.promptFile, "prompt.txt")
XCTAssertEqual(originalParams.pathPromptCache, "cache/path")
XCTAssertEqual(originalParams.inputPrefix, "prefix")
XCTAssertEqual(originalParams.inputSuffix, "suffix")
XCTAssertEqual(originalParams.logdir, "log/dir")
XCTAssertEqual(originalParams.lookupCacheStatic, "static/cache")
XCTAssertEqual(originalParams.lookupCacheDynamic, "dynamic/cache")
XCTAssertEqual(originalParams.logitsFile, "logits.txt")
XCTAssertEqual(originalParams.rpcServers, "servers")
// Verify array properties
XCTAssertEqual(originalParams.inputFiles, ["input1.txt", "input2.txt"])
XCTAssertEqual(originalParams.antiPrompts, ["anti1", "anti2"])
XCTAssertEqual(originalParams.kvOverrides as? [String], ["override1", "override2"])
// XCTAssertEqual(originalParams.loraAdapters, ["adapter1", "adapter2"])
// XCTAssertEqual(originalParams.controlVectors, ["control1", "control2"])
// Verify boolean and control properties
XCTAssertTrue(originalParams.loraInitWithoutApply)
XCTAssertEqual(originalParams.verbosity, 1)
XCTAssertEqual(originalParams.controlVectorLayerStart, 2)
XCTAssertEqual(originalParams.controlVectorLayerEnd, 3)
XCTAssertEqual(originalParams.pplStride, 4)
XCTAssertEqual(originalParams.pplOutputType, 5)
XCTAssertTrue(originalParams.hellaswag)
XCTAssertEqual(originalParams.hellaswagTasks, 10)
XCTAssertFalse(originalParams.winogrande)
XCTAssertEqual(originalParams.winograndeTasks, 20)
XCTAssertTrue(originalParams.multipleChoice)
XCTAssertEqual(originalParams.multipleChoiceTasks, 30)
XCTAssertFalse(originalParams.klDivergence)
XCTAssertTrue(originalParams.usage)
XCTAssertFalse(originalParams.useColor)
XCTAssertTrue(originalParams.special)
XCTAssertFalse(originalParams.interactive)
XCTAssertTrue(originalParams.interactiveFirst)
XCTAssertFalse(originalParams.conversation)
XCTAssertTrue(originalParams.promptCacheAll)
XCTAssertFalse(originalParams.promptCacheRO)
XCTAssertTrue(originalParams.escapeSequences)
XCTAssertFalse(originalParams.multilineInput)
XCTAssertTrue(originalParams.simpleIO)
XCTAssertFalse(originalParams.continuousBatching)
XCTAssertTrue(originalParams.flashAttention)
XCTAssertFalse(originalParams.noPerformanceMetrics)
XCTAssertTrue(originalParams.contextShift)
// Verify server and I/O settings
XCTAssertEqual(originalParams.port, 8080)
XCTAssertEqual(originalParams.timeoutRead, 60)
XCTAssertEqual(originalParams.timeoutWrite, 30)
XCTAssertEqual(originalParams.httpThreads, 4)
XCTAssertEqual(originalParams.hostname, "localhost")
XCTAssertEqual(originalParams.publicPath, "/public")
XCTAssertEqual(originalParams.chatTemplate, "template")
XCTAssertEqual(originalParams.systemPrompt, "system prompt")
XCTAssertTrue(originalParams.enableChatTemplate)
XCTAssertEqual(originalParams.apiKeys, ["key1", "key2"])
XCTAssertEqual(originalParams.sslFileKey, "key.pem")
XCTAssertEqual(originalParams.sslFileCert, "cert.pem")
XCTAssertTrue(originalParams.endpointSlots)
XCTAssertFalse(originalParams.endpointMetrics)
XCTAssertTrue(originalParams.logJSON)
XCTAssertEqual(originalParams.slotSavePath, "/slots")
XCTAssertEqual(originalParams.slotPromptSimilarity, 0.75)
// Verify batched-bench params
XCTAssertTrue(originalParams.isPPShared)
XCTAssertEqual(originalParams.nPP, [1, 2])
XCTAssertEqual(originalParams.nTG, [3, 4])
XCTAssertEqual(originalParams.nPL, [5, 6])
// Verify retrieval params
XCTAssertEqual(originalParams.contextFiles, ["context1.txt", "context2.txt"])
XCTAssertEqual(originalParams.chunkSize, 1024)
XCTAssertEqual(originalParams.chunkSeparator, "\n")
// Verify passkey params
XCTAssertEqual(originalParams.nJunk, 7)
XCTAssertEqual(originalParams.iPos, 8)
// Verify imatrix params
XCTAssertEqual(originalParams.outFile, "output.txt")
XCTAssertEqual(originalParams.nOutFreq, 100)
XCTAssertEqual(originalParams.nSaveFreq, 200)
XCTAssertEqual(originalParams.iChunk, 9)
XCTAssertTrue(originalParams.processOutput)
XCTAssertFalse(originalParams.computePPL)
// Verify cvector-generator params
XCTAssertEqual(originalParams.nPCABatch, 10)
XCTAssertEqual(originalParams.nPCAIterations, 11)
XCTAssertEqual(originalParams.cvectorDimreMethod, 12)
XCTAssertEqual(originalParams.cvectorOutfile, "cvector.out")
XCTAssertEqual(originalParams.cvectorPositiveFile, "positive.txt")
XCTAssertEqual(originalParams.cvectorNegativeFile, "negative.txt")
// Verify additional properties
XCTAssertTrue(originalParams.spmInfill)
XCTAssertEqual(originalParams.loraOutfile, "lora.out")
XCTAssertFalse(originalParams.embedding)
XCTAssertTrue(originalParams.verbosePrompt)
XCTAssertFalse(originalParams.batchedBenchOutputJSONL)
XCTAssertTrue(originalParams.inputPrefixBOS)
XCTAssertFalse(originalParams.ctxShift)
XCTAssertTrue(originalParams.displayPrompt)
XCTAssertFalse(originalParams.logging)
// Test the copy function
guard let copiedParams = originalParams.copy() as? GPTParams else {
XCTFail("Copy function did not return a GPTParams instance.")
return
}
// Verify that the copied properties match the original
XCTAssertEqual(copiedParams.nPredict, originalParams.nPredict)
XCTAssertEqual(copiedParams.nCtx, originalParams.nCtx)
XCTAssertEqual(copiedParams.nBatch, originalParams.nBatch)
XCTAssertEqual(copiedParams.nUBatch, originalParams.nUBatch)
XCTAssertEqual(copiedParams.nKeep, originalParams.nKeep)
XCTAssertEqual(copiedParams.nDraft, originalParams.nDraft)
XCTAssertEqual(copiedParams.nChunks, originalParams.nChunks)
XCTAssertEqual(copiedParams.nParallel, originalParams.nParallel)
XCTAssertEqual(copiedParams.nSequences, originalParams.nSequences)
XCTAssertEqual(copiedParams.pSplit, originalParams.pSplit)
XCTAssertEqual(copiedParams.nGpuLayers, originalParams.nGpuLayers)
XCTAssertEqual(copiedParams.nGpuLayersDraft, originalParams.nGpuLayersDraft)
XCTAssertEqual(copiedParams.mainGpu, originalParams.mainGpu)
XCTAssertEqual(copiedParams.tensorSplit, originalParams.tensorSplit)
XCTAssertEqual(copiedParams.grpAttnN, originalParams.grpAttnN)
XCTAssertEqual(copiedParams.grpAttnW, originalParams.grpAttnW)
XCTAssertEqual(copiedParams.nPrint, originalParams.nPrint)
XCTAssertEqual(copiedParams.ropeFreqBase, originalParams.ropeFreqBase)
XCTAssertEqual(copiedParams.ropeFreqScale, originalParams.ropeFreqScale)
XCTAssertEqual(copiedParams.yarnExtFactor, originalParams.yarnExtFactor)
XCTAssertEqual(copiedParams.yarnAttnFactor, originalParams.yarnAttnFactor)
XCTAssertEqual(copiedParams.yarnBetaFast, originalParams.yarnBetaFast)
XCTAssertEqual(copiedParams.yarnBetaSlow, originalParams.yarnBetaSlow)
XCTAssertEqual(copiedParams.yarnOrigCtx, originalParams.yarnOrigCtx)
XCTAssertEqual(copiedParams.defragThold, originalParams.defragThold)
XCTAssertEqual(copiedParams.modelPath, originalParams.modelPath)
XCTAssertEqual(copiedParams.apiKeys, originalParams.apiKeys)
// XCTAssertEqual(copiedParams.controlVectors, originalParams.controlVectors)
// Continue verifying all other properties...
// Verify that modifying the original does not affect the copy
originalParams.nPredict = 999
XCTAssertNotEqual(copiedParams.nPredict, originalParams.nPredict)
}
}

View file

@@ -2,27 +2,28 @@ import Foundation
import Testing
@testable import LlamaKit
import JSONSchema
import OSLog
// MARK: LlamaGrammarSession Suite
@Suite("LlamaGrammarSession Suite")
struct LlamaGrammarSessionSuite {
@Suite("LlamaSession Suite")
struct LlamaSessionSuite {
@JSONSchema struct Trip {
let location: String
let startDate: TimeInterval
let durationInDays: Int
}
func downloadFile() async throws -> String {
func downloadFile(url: String, to path: String) async throws -> String {
let fm = FileManager.default
let tmpDir = fm.temporaryDirectory
let destinationURL = tmpDir.appending(path: "tinyllama.gguf")
let destinationURL = tmpDir.appending(path: path)
guard !fm.fileExists(atPath: destinationURL.path()) else {
return destinationURL.path()
}
print("Downloading TinyLlama, this may take a while...")
print("Downloading \(path), this may take a while...")
// Define the URL
guard let url = URL(string: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf?download=true") else {
guard let url = URL(string: url) else {
print("Invalid URL.")
throw URLError(.badURL)
}
@@ -32,22 +33,22 @@ struct LlamaGrammarSessionSuite {
// Define the destination path in the documents directory
// Move the downloaded file to the destination
try fm.moveItem(at: tempURL, to: destinationURL)
print("File downloaded to: \(destinationURL.path())")
return destinationURL.path()
}
@Test func llamaGrammarSession() async throws {
func baseParams(url: String, to path: String) async throws -> GPTParams {
let params = GPTParams()
params.modelPath = try await downloadFile()
params.nPredict = 256
params.nCtx = 1024
params.cpuParams.nThreads = 4
params.cpuParamsBatch.nThreads = 4
params.modelPath = try await downloadFile(url: url, to: path)
params.nPredict = 512
params.nCtx = 4096
params.cpuParams.nThreads = 8
params.cpuParamsBatch.nThreads = 8
params.nBatch = 1024
params.nGpuLayers = 128
params.nGpuLayers = 1024
params.chatTemplate = """
<|system|>
{system_message}</s>
@@ -55,6 +56,28 @@ struct LlamaGrammarSessionSuite {
{prompt}</s>
<|assistant|>
"""
params.interactive = true
return params
}
@Test func llamaInferenceSession() async throws {
let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf")
params.prompt = """
<|system|>
You are an AI assistant. Answer queries simply and concisely.</s>
"""
params.antiPrompts = ["</s>"]
params.inputPrefix = "<|user|>"
params.inputSuffix = "</s><|assistant|>"
params.interactive = true
let session = try await LlamaChatSession(params: params, flush: false)
for await msg in await session.inferenceStream(message: "How are you today?") {
print(msg, terminator: "")
}
}
@Test func llamaGrammarSession() async throws {
let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf")
params.prompt = """
You are a travel agent. The current date, as an epoch timestamp, is \(Date.now.timeIntervalSince1970).
Responses should have the following fields:
@@ -64,7 +87,6 @@ struct LlamaGrammarSessionSuite {
durationInDays: the duration of the trip in days
"""
params.interactive = true
let session = try await LlamaSession<Trip>(params: params)
await #expect(throws: Never.self) {
let trip = try await session.chat(message: "Please create a trip for me to New York City that starts two weeks from now. The duration of the trip MUST be 3 days long.")
@@ -73,6 +95,25 @@ struct LlamaGrammarSessionSuite {
// TODO: so for now, we are just asserting the grammar works
}
}
@JSONSchema struct IsCorrect {
let isSpellingCorrect: Bool
}
@Test func llamaSimpleGrammarSession() async throws {
let params = try await baseParams(url: "https://huggingface.co/RichardErkhov/openfoodfacts_-_spellcheck-mistral-7b-gguf/resolve/main/spellcheck-mistral-7b.Q8_0.gguf?download=true",
to: "spellcheck_q8.gguf")
params.prompt = """
###You are a spell checker. I will provide you with the word 'strawberry'. If the spelling of the given word is correct, please respond {"isSpellingCorrect": true} else respond {"isSpellingCorrect": false}.\n
"""
let session = try await LlamaSession<IsCorrect>(params: params)
for _ in 0..<10 {
var output = try await session.chat(message: "###strawberry\n")
#expect(output.isSpellingCorrect)
output = try await session.chat(message: "###strawberrry\n")
#expect(!output.isSpellingCorrect)
}
}
}
import WeatherKit
@@ -115,7 +156,7 @@ func downloadFile() async throws -> String {
/// Get the current weather in a given location.
/// - parameter location: The city and state, e.g. San Francisco, CA
/// - parameter unit: The unit of temperature
@Tool public static func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather {
@Tool public func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather {
let weather = try await WeatherService().weather(for: CLGeocoder().geocodeAddressString(location)[0].location!)
var temperature = weather.currentWeather.temperature
temperature.convert(to: .fahrenheit)
@@ -134,7 +175,7 @@ func downloadFile() async throws -> String {
params.nBatch = 1024
params.nGpuLayers = 1024
let llama = try await MyLlama(params: params)
let currentWeather = try await MyLlama.getCurrentWeather(location: "San Francisco, CA", unit: "fahrenheit")
let currentWeather = try await llama.getCurrentWeather(location: "San Francisco, CA", unit: "fahrenheit")
let output = try await llama.chat("What's the weather (in fahrenheit) in San Francisco, CA?")
#expect(output.contains(String(format: "%d", Int(currentWeather.temperature))))
// #expect(output.contains(currentWeather.condition.rawValue))