diff --git a/Package.swift b/Package.swift index 47d131f4e..95339fb2f 100644 --- a/Package.swift +++ b/Package.swift @@ -32,7 +32,7 @@ var cSettings: [CSetting] = [ // We should consider add this in the future when we drop support for iOS 14 // (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc) .define("ACCELERATE_NEW_LAPACK"), - .define("ACCELERATE_LAPACK_ILP64") + .define("ACCELERATE_LAPACK_ILP64"), ] #if canImport(Darwin) @@ -80,7 +80,8 @@ let package = Package( "ggml" ], sources: cppSources, - publicHeadersPath: "spm-headers"), + publicHeadersPath: "spm-headers", + cSettings: cSettings), .target( name: "llama", dependencies: ["llama_cpp"], @@ -94,6 +95,8 @@ let package = Package( dependencies: ["llama"], path: "objc", sources: [ + "CPUParams.mm", + "GGMLThreadpool.mm", "GPTParams.mm", "GPTSampler.mm", "LlamaBatch.mm", diff --git a/objc/CPUParams.mm b/objc/CPUParams.mm new file mode 100644 index 000000000..56fee7b3f --- /dev/null +++ b/objc/CPUParams.mm @@ -0,0 +1,68 @@ +#import +#import "CPUParams_Private.hpp" +#import "GGMLThreadpool_Private.hpp" + +@implementation CPUParams + +- (instancetype)initWithParams:(cpu_params&)params +{ + self = [super init]; + if (self) { + self->params = ¶ms; + } + return self; +} + +- (NSInteger)nThreads { + return params->n_threads; +} + +- (void)setNThreads:(NSInteger)nThreads { + params->n_threads = nThreads; +} + +- (BOOL)maskValid { + return params->mask_valid; +} + +- (void)setMaskValid:(BOOL)maskValid { + params->mask_valid = maskValid; +} + +- (GGMLSchedPriority)priority { + return GGMLSchedPriority(params->priority); +} + +- (void)setPriority:(GGMLSchedPriority)priority { + params->priority = ggml_sched_priority(priority); +} + +- (BOOL)strictCPU { + return params->strict_cpu; +} + +- (void)setStrictCPU:(BOOL)strictCPU { + params->strict_cpu = strictCPU; +} + +- (NSUInteger)poll { + return params->poll; +} + +- (void)setPoll:(NSUInteger)poll { + params->poll = poll; +} + +- (BOOL)getCpuMaskAtIndex:(NSUInteger)index { + return params->cpumask[index]; +} + +- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index { + params->cpumask[index] = value; +} + +- (GGMLThreadpoolParams *)ggmlThreadpoolParams { + return [[GGMLThreadpoolParams alloc] initWithParams:ggml_threadpool_params_from_cpu_params(*params)]; +} + +@end diff --git a/objc/GGMLThreadpool.mm b/objc/GGMLThreadpool.mm new file mode 100644 index 000000000..5895b161d --- /dev/null +++ b/objc/GGMLThreadpool.mm @@ -0,0 +1,52 @@ +#import +#import "GGMLThreadpool_Private.hpp" + +@implementation GGMLThreadpool + +- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool +{ + self = [super init]; + if (self) { + self->threadpool = threadpool; + } + return self; +} + +- (ggml_threadpool *)threadpool { + return self->threadpool; +} + +@end + +@implementation GGMLThreadpoolParams { + ggml_threadpool_params params; +} + +- (BOOL)getCpuMaskAtIndex:(NSUInteger)index { + abort(); +} + +- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index { + abort(); +} + +- (instancetype)initWithParams:(ggml_threadpool_params&&)params +{ + self = [super init]; + if (self) { + self->params = params; + } + return self; +} + +- (BOOL)isEqual:(id)other { + GGMLThreadpoolParams *rhs = (GGMLThreadpoolParams *)other; + ggml_threadpool_params rhs_params = rhs->params; + return ggml_threadpool_params_match(¶ms, &rhs_params); +} + +- (GGMLThreadpool *)threadpool { + auto tp = ggml_threadpool_new(¶ms); + return [[GGMLThreadpool alloc] initWithThreadpool:tp]; +} +@end diff --git a/objc/GPTParams.mm b/objc/GPTParams.mm index f6b99f635..1ada5a577 100644 --- a/objc/GPTParams.mm +++ b/objc/GPTParams.mm @@ -1,127 +1,10 @@ #import #import "GPTParams_Private.hpp" +#import "CPUParams_Private.hpp" +#import "GPTSampler.h" #import "../common/common.h" #import "ggml.h" -@implementation GGMLThreadpool { - ggml_threadpool *threadpool; -} - -- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool -{ - self = [super init]; - if (self) { - self->threadpool = threadpool; - } - return self; -} - -- (ggml_threadpool *)threadpool { - return threadpool; -} - -@end - -@implementation GGMLThreadpoolParams { - ggml_threadpool_params params; -} - -- (BOOL)getCpuMaskAtIndex:(NSUInteger)index { - abort(); -} - -- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index { - abort(); -} - -- (instancetype)initWithParams:(ggml_threadpool_params&&)params -{ - self = [super init]; - if (self) { - self->params = params; - } - return self; -} - -- (BOOL)isEqual:(id)other { - GGMLThreadpoolParams *rhs = (GGMLThreadpoolParams *)other; - ggml_threadpool_params rhs_params = rhs->params; - return ggml_threadpool_params_match(¶ms, &rhs_params); -} - -- (GGMLThreadpool *)threadpool { - auto tp = ggml_threadpool_new(¶ms); - return [[GGMLThreadpool alloc] initWithThreadpool:tp]; -} -@end - -@implementation CPUParams { - cpu_params *params; -} - -- (instancetype)initWithParams:(cpu_params&)params; -{ - self = [super init]; - if (self) { - self->params = ¶ms; - } - return self; -} - -- (int)nThreads { - return params->n_threads; -} - -- (void)setNThreads:(int)nThreads { - params->n_threads = nThreads; -} - -- (BOOL)maskValid { - return params->mask_valid; -} - -- (void)setMaskValid:(BOOL)maskValid { - params->mask_valid = maskValid; -} - -- (GGMLSchedPriority)priority { - return GGMLSchedPriority(params->priority); -} - -- (void)setPriority:(GGMLSchedPriority)priority { - params->priority = ggml_sched_priority(priority); -} - -- (BOOL)strictCPU { - return params->strict_cpu; -} - -- (void)setStrictCPU:(BOOL)strictCPU { - params->strict_cpu = strictCPU; -} - -- (uint32_t)poll { - return params->poll; -} - -- (void)setPoll:(uint32_t)poll { - params->poll = poll; -} - -- (BOOL)getCpuMaskAtIndex:(NSUInteger)index { - return params->cpumask[index]; -} - -- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index { - params->cpumask[index] = value; -} - -- (GGMLThreadpoolParams *)ggmlThreadpoolParams { - return [[GGMLThreadpoolParams alloc] initWithParams:ggml_threadpool_params_from_cpu_params(*params)]; -} - -@end - @implementation GPTSamplerParams { common_sampler_params *gpt_sampler_params; } @@ -415,6 +298,42 @@ return antiprompts; } +- (void)setAntiPrompts:(NSArray *)antiPrompts { + gpt_params.antiprompt.clear(); + for (NSString *antiprompt in antiPrompts) { + gpt_params.antiprompt.push_back([antiprompt cStringUsingEncoding:NSUTF8StringEncoding]); + } +} + +- (NSArray *)apiKeys { + auto apiKeys = [[NSMutableArray alloc] init]; + for (auto& apiKey : gpt_params.api_keys) { + [apiKeys addObject:[NSString stringWithCString:apiKey.c_str() encoding:NSUTF8StringEncoding]]; + } + return apiKeys; +} + +- (void)setApiKeys:(NSArray *)apiKeys { + gpt_params.api_keys.clear(); + for (NSString *apiKey in apiKeys) { + gpt_params.api_keys.push_back([apiKey cStringUsingEncoding:NSUTF8StringEncoding]); + } +} + +- (NSArray *)tensorSplit { + auto tensorSplit = [[NSMutableArray alloc] init]; + for (auto& tensor : gpt_params.tensor_split) { + [tensorSplit addObject:[[NSNumber alloc] initWithFloat:tensor]]; + } + return tensorSplit; +} + +- (void)setTensorSplit:(NSArray *)tensorSplit { + for (size_t i = 0; i < [tensorSplit count]; i++) { + gpt_params.tensor_split[i] = [tensorSplit[i] floatValue]; + } +} + - (common_params&)params { return gpt_params; } @@ -716,4 +635,29 @@ gpt_params.input_suffix = [inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]; } +- (BOOL)interactive { + return gpt_params.interactive; +} + +- (void)setInteractive:(BOOL)interactive { + gpt_params.interactive = interactive; +} + +- (BOOL)interactiveFirst { + return gpt_params.interactive_first; +} + +- (void)setInteractiveFirst:(BOOL)interactiveFirst { + gpt_params.interactive_first = interactiveFirst; +} +- (id)copyWithZone:(NSZone *)zone { + GPTParams *copy = [[[self class] allocWithZone:zone] init]; + + if (copy) { + copy->gpt_params = gpt_params; + } + + return copy; +} + @end diff --git a/objc/LlamaContext.mm b/objc/LlamaContext.mm index 50ed8a3a3..b07cfd33d 100644 --- a/objc/LlamaContext.mm +++ b/objc/LlamaContext.mm @@ -1,5 +1,6 @@ #import #import "LlamaContext_Private.hpp" +#import "GGMLThreadpool_Private.hpp" #import "GPTParams_Private.hpp" #import "LlamaModel_Private.hpp" #import "LlamaBatch_Private.hpp" @@ -17,6 +18,12 @@ return self; } +- (void)dealloc +{ + [super dealloc]; + llama_free(ctx); +} + - (void)attachThreadpool:(GGMLThreadpool *)threadpool threadpoolBatch:(GGMLThreadpool *)threadpoolBatch { llama_attach_threadpool(ctx, [threadpool threadpool], [threadpoolBatch threadpool]); diff --git a/objc/LlamaModel.mm b/objc/LlamaModel.mm index 38d2015f5..1ee4cb520 100644 --- a/objc/LlamaModel.mm +++ b/objc/LlamaModel.mm @@ -22,6 +22,12 @@ return self; } +- (void)dealloc +{ + [super dealloc]; + llama_free_model(model); +} + - (LlamaContext *)context:(LlamaContextParams *)params { return nil; } diff --git a/objc/LlamaSession.mm b/objc/LlamaSession.mm index 60819db4c..bb29f1e2e 100644 --- a/objc/LlamaSession.mm +++ b/objc/LlamaSession.mm @@ -3,10 +3,12 @@ #import "../../common/common.h" #import "LlamaModel_Private.hpp" #import "LlamaContext_Private.hpp" +#import "CPUParams_Private.hpp" #import "GPTSampler.h" #import #import "ggml.h" #import "GPTParams_Private.hpp" +#import "GGMLThreadpool_Private.hpp" #import "LlamaBatch_Private.hpp" @implementation BlockingLineQueue { @@ -75,8 +77,7 @@ @implementation LlamaSession { std::vector embd_inp; std::vector chat_msgs; - GPTParams *params; - GPTSampler *smpl; + BOOL isInteracting; bool is_antiprompt; @@ -105,13 +106,14 @@ std::vector> antiprompt_ids; BOOL need_insert_eot; int n_ctx; + os_log_t os_log_inst; } - (NSString *)chat_add_and_format:(std::vector &) chat_msgs role:(const std::string &) role content:(const std::string &) content { common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single([self.model cModel], [params params].chat_template, chat_msgs, new_msg, role == "user"); + auto formatted = common_chat_format_single([self.model cModel], [_params params].chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); - os_log_debug(OS_LOG_DEFAULT, "formatted: '%s'\n", formatted.c_str()); + os_log_debug(os_log_inst, "formatted: '%s'\n", formatted.c_str()); return [NSString stringWithCString:formatted.c_str() encoding:NSUTF8StringEncoding]; } @@ -131,22 +133,24 @@ static BOOL file_is_empty(NSString *path) { - (instancetype)initWithParams:(GPTParams *)params { self = [super init]; - - self->params = params; - // model = llama_init.model; - // ctx = llama_init.context; - // - // if model == nil { - // LOG_ERR("%s: error: unable to load model\n", __func__); - // return 1; - // } - // - os_log_info(OS_LOG_DEFAULT, - "%s: llama threadpool init, n_threads = %d\n", - __func__, params.cpuParams.nThreads); + self->_params = [params copy]; + self->_mutableLastOutput = [[NSMutableString alloc] init]; + if (params.logging) { + os_log_inst = OS_LOG_DEFAULT; + } else { + os_log_inst = OS_LOG_DISABLED; + } + if (!params.modelPath) { + [NSException raise:@"ModelFailure" + format:@"params.modelPath must be defined"]; + } + + os_log_info(os_log_inst, + "%s: llama threadpool init, n_threads = %ld\n", + __func__, static_cast(params.cpuParams.nThreads)); if (params.embedding) { - os_log_error(OS_LOG_DEFAULT, + os_log_error(os_log_inst, R"(************ please use the 'embedding' tool for embedding calculations ************)"); @@ -154,22 +158,29 @@ static BOOL file_is_empty(NSString *path) { } if (params.nCtx != 0 && params.nCtx < 8) { - os_log_info(OS_LOG_DEFAULT, "minimum context size is 8, using minimum size."); + os_log_info(os_log_inst, "minimum context size is 8, using minimum size."); params.nCtx = 8; } if (params.ropeFreqBase != 0) { - os_log_info(OS_LOG_DEFAULT, "changing RoPE frequency base to \(params.ropeFreqBase)"); + os_log_info(os_log_inst, "changing RoPE frequency base to \(params.ropeFreqBase)"); } if (params.ropeFreqScale != 0.0) { - os_log_info(OS_LOG_DEFAULT, "scaling RoPE frequency by \(params.ropeFreqScale)"); + os_log_info(os_log_inst, "scaling RoPE frequency by \(params.ropeFreqScale)"); } llama_backend_init(); llama_numa_init(ggml_numa_strategy(params.numaStrategy)); auto llama_init = common_init_from_params([params params]); - + if (llama_init.context == nil) { + [NSException raise:@"ContextFailure" + format:@"could not load context"]; + } + if (llama_init.model == nil) { + [NSException raise:@"ModelLoadFailure" + format:@"could not load model"]; + } auto tpp_batch = params.cpuParamsBatch.ggmlThreadpoolParams; auto tpp = params.cpuParams.ggmlThreadpoolParams; @@ -179,7 +190,7 @@ static BOOL file_is_empty(NSString *path) { if (tpp != tpp_batch) { threadpool_batch = [tpp_batch threadpool]; if (!threadpool_batch) { - [NSException raise:@"batch threadpool create failed" + [NSException raise:@"ThreadpoolFailure" format:@"batch threadpool create failed"]; } @@ -189,7 +200,7 @@ static BOOL file_is_empty(NSString *path) { GGMLThreadpool *threadpool = [tpp threadpool]; if (!threadpool) { - [NSException raise:@"threadpool create failed" + [NSException raise:@"ThreadpoolFailure" format:@"threadpool create failed"]; } @@ -200,16 +211,16 @@ static BOOL file_is_empty(NSString *path) { n_ctx = [self.ctx nCtx]; // if (n_ctx > n_ctx_train) { - os_log_info(OS_LOG_DEFAULT, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); + os_log_info(os_log_inst, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } // print chat template example in conversation mode if (params.conversation) { if (params.enableChatTemplate) { - os_log_info(OS_LOG_DEFAULT, "%s: chat template example:\n%s\n", __func__, + os_log_info(os_log_inst, "%s: chat template example:\n%s\n", __func__, [[self.model formatExample:params.chatTemplate] cStringUsingEncoding:NSUTF8StringEncoding]); } else { - os_log_info(OS_LOG_DEFAULT, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); + os_log_info(os_log_inst, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } } // print system information @@ -222,11 +233,11 @@ static BOOL file_is_empty(NSString *path) { NSFileManager *fileManager = [NSFileManager defaultManager]; if ([pathSession length] != 0) { - os_log_info(OS_LOG_DEFAULT, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]); if (![fileManager fileExistsAtPath:pathSession]) { - os_log_info(OS_LOG_DEFAULT, "%s: session file does not exist, will create.\n", __func__); + os_log_info(os_log_inst, "%s: session file does not exist, will create.\n", __func__); } else if (file_is_empty(pathSession)) { - os_log_info(OS_LOG_DEFAULT,"%s: The session file is empty. A new session will be initialized.\n", __func__); + os_log_info(os_log_inst,"%s: The session file is empty. A new session will be initialized.\n", __func__); } else { // The file exists and is not empty session_tokens.resize(n_ctx); @@ -235,7 +246,7 @@ static BOOL file_is_empty(NSString *path) { [NSException raise:@"SessionLoadFailure" format:@"%s: failed to load session file '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]]; } session_tokens.resize(n_token_count_out); - os_log_info(OS_LOG_DEFAULT,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); + os_log_info(os_log_inst,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); } } @@ -244,7 +255,7 @@ static BOOL file_is_empty(NSString *path) { GGML_ASSERT(![self.model addEOSToken]); } - os_log_debug(OS_LOG_DEFAULT, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS); + os_log_debug(os_log_inst, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS); { @@ -252,22 +263,22 @@ static BOOL file_is_empty(NSString *path) { ? [self chat_add_and_format:chat_msgs role:"system" content:[params params].prompt] // format the system prompt in conversation mode : params.prompt; if (params.interactiveFirst || [params.prompt length] > 0 || session_tokens.empty()) { - os_log_debug(OS_LOG_DEFAULT, "tokenize the prompt\n"); + os_log_debug(os_log_inst, "tokenize the prompt\n"); embd_inp = [self.ctx tokenize:prompt addSpecial:true parseSpecial:true]; } else { - os_log_debug(OS_LOG_DEFAULT,"use session tokens\n"); + os_log_debug(os_log_inst,"use session tokens\n"); embd_inp = session_tokens; } - os_log_debug(OS_LOG_DEFAULT,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]); - os_log_debug(OS_LOG_DEFAULT,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str()); + os_log_debug(os_log_inst,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_debug(os_log_inst,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str()); } // Should not run without any tokens if (embd_inp.empty()) { if (addBOS) { embd_inp.push_back([self.model tokenBOS]); -// LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); + os_log_info(os_log_inst, "embd_inp was considered empty and bos was added: %s\n", [_ctx convertTokensToString:embd_inp].c_str()); } else { [NSException raise:@"InputEmptyError" format:@"input is empty"]; } @@ -303,13 +314,13 @@ static BOOL file_is_empty(NSString *path) { llama_kv_cache_seq_rm([self.ctx cContext], -1, n_matching_session_tokens, -1); } // - // os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", + // os_log_debug(os_log_inst, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", // embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size()); // // if we will use the cache for the full prompt without reaching the end of the cache, force // reevaluation of the last token to recalculate the cached logits if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { -// os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); +// os_log_debug(os_log_inst, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); } @@ -331,22 +342,21 @@ static BOOL file_is_empty(NSString *path) { } if (params.verbosePrompt) { -// LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + os_log_info(os_log_inst, + "%s: prompt: '%s'\n", __func__, [params.prompt cStringUsingEncoding:NSUTF8StringEncoding]); // LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", embd_inp[i], + os_log_info(os_log_inst, "%6d -> '%s'\n", embd_inp[i], [[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); } if (params.nKeep > addBOS) { // LOG_INF("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.nKeep; i++) { - os_log_debug(OS_LOG_DEFAULT, "%s", + os_log_debug(os_log_inst, "%s", [[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); } -// LOG("'\n"); } -// LOG_INF("\n"); } // // // ctrl+C handling @@ -366,55 +376,55 @@ static BOOL file_is_empty(NSString *path) { // } // if (params.interactive) { - os_log_info(OS_LOG_DEFAULT, "%s: interactive mode on.\n", __func__); + os_log_info(os_log_inst, "%s: interactive mode on.\n", __func__); if ([params.antiPrompts count] > 0) { for (NSString *antiprompt in params.antiPrompts) { - os_log_info(OS_LOG_DEFAULT, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]); if (params.verbosePrompt) { auto tmp = [_ctx tokenize:antiprompt addSpecial:false parseSpecial:true]; for (int i = 0; i < (int) tmp.size(); i++) { - os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); } } } } if (params.inputPrefixBOS) { - os_log_info(OS_LOG_DEFAULT, "Input prefix with BOS\n"); + os_log_info(os_log_inst, "Input prefix with BOS\n"); } if ([params.inputPrefix length] > 0) { - os_log_info(OS_LOG_DEFAULT, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); if (params.verbosePrompt) { auto tmp = [_ctx tokenize:params.inputPrefix addSpecial:true parseSpecial:true]; for (int i = 0; i < (int) tmp.size(); i++) { - os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", + os_log_info(os_log_inst, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); } } } if ([params.inputSuffix length] > 0) { - os_log_info(OS_LOG_DEFAULT, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); if (params.verbosePrompt) { auto tmp = [_ctx tokenize:params.inputSuffix addSpecial:false parseSpecial:true]; for (int i = 0; i < (int) tmp.size(); i++) { - os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", + os_log_info(os_log_inst, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]); } } } } - smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]]; - if (!smpl) { + _smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]]; + if (!_smpl) { [NSException raise:@"SamplingFailure" format:@"failed to initialize sampling subsystem"]; } - os_log_info(OS_LOG_DEFAULT, "sampler seed: %u\n", [smpl seed]); + os_log_info(os_log_inst, "sampler seed: %u\n", [_smpl seed]); // LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); // LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str()); // @@ -431,7 +441,7 @@ static BOOL file_is_empty(NSString *path) { GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT - os_log_info(OS_LOG_DEFAULT, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast(ga_n), static_cast(ga_w)); + os_log_info(os_log_inst, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast(ga_n), static_cast(ga_w)); } if (params.interactive) { @@ -454,14 +464,6 @@ static BOOL file_is_empty(NSString *path) { need_to_save_session = [pathSession length] > 0 && n_matching_session_tokens < embd_inp.size(); n_remain = params.nPredict; - // // the first thing we will do is to output the prompt, so set color accordingly - // console::set_display(console::prompt); - // display = params.display_prompt; - // - - - - antiprompt_ids.reserve([params.antiPrompts count]); for (NSString *antiprompt in params.antiPrompts) { antiprompt_ids.emplace_back([self.ctx tokenize:antiprompt addSpecial:false parseSpecial:true]); @@ -486,8 +488,13 @@ static BOOL file_is_empty(NSString *path) { return self; } +// MARK: LastOutput +- (NSString *)lastOutput { + return [_mutableLastOutput copy]; +} + - (void)start:(BlockingLineQueue *)queue { - while ((n_remain != 0 && !is_antiprompt) || params.interactive) { + while ((n_remain != 0 && !is_antiprompt) || _params.interactive) { // predict if (!embd.empty()) { // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via @@ -500,42 +507,42 @@ static BOOL file_is_empty(NSString *path) { embd.resize(max_embd_size); // console::set_display(console::error); - os_log_error(OS_LOG_DEFAULT, "<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + os_log_error(os_log_inst, "<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); // console::set_display(console::reset); } - if (params.grpAttnN == 1) { + if (_params.grpAttnN == 1) { // infinite text generation via context shifting // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches if (n_past + (int) embd.size() >= [_ctx nCtx]) { - if (!params.ctxShift) { - os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and context shift is disabled => stopping\n", __func__); + if (!_params.ctxShift) { + os_log_debug(os_log_inst, "\n\n%s: context full and context shift is disabled => stopping\n", __func__); break; } else { - if (params.nPredict == -2) { - os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.nPredict); + if (_params.nPredict == -2) { + os_log_debug(os_log_inst, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, _params.nPredict); break; } - const int n_left = n_past - params.nKeep; + const int n_left = n_past - _params.nKeep; const int n_discard = n_left/2; - os_log_debug(OS_LOG_DEFAULT, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n", - n_past, n_left, static_cast([_ctx nCtx]), params.nKeep, n_discard); + os_log_debug(os_log_inst, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n", + n_past, n_left, static_cast([_ctx nCtx]), _params.nKeep, n_discard); - llama_kv_cache_seq_rm ([self.ctx cContext], 0, params.nKeep , params.nKeep + n_discard); - llama_kv_cache_seq_add([self.ctx cContext], 0, params.nKeep + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm ([self.ctx cContext], 0, _params.nKeep , _params.nKeep + n_discard); + llama_kv_cache_seq_add([self.ctx cContext], 0, _params.nKeep + n_discard, n_past, -n_discard); n_past -= n_discard; - os_log_debug(OS_LOG_DEFAULT, "after swap: n_past = %d\n", n_past); + os_log_debug(os_log_inst, "after swap: n_past = %d\n", n_past); - os_log_debug(OS_LOG_DEFAULT, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str()); + os_log_debug(os_log_inst, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str()); - os_log_debug(OS_LOG_DEFAULT, "clear session path\n"); + os_log_debug(os_log_inst, "clear session path\n"); [pathSession setString:@""]; } } @@ -546,10 +553,10 @@ static BOOL file_is_empty(NSString *path) { const int bd = (ga_w/ga_n)*(ga_n - 1); const int dd = (ga_w/ga_n) - ib*bd - ga_w; - os_log_debug(OS_LOG_DEFAULT, "\n"); - os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast(ga_i), n_past, ib*bd, static_cast(ga_i + ib*bd), n_past + ib*bd); - os_log_debug(OS_LOG_DEFAULT, "div: [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast(ga_i + ib*bd), static_cast(ga_i + ib*bd + ga_w), static_cast(ga_n), static_cast((ga_i + ib*bd)/ga_n), static_cast((ga_i + ib*bd + ga_w)/ga_n)); - os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd); + os_log_debug(os_log_inst, "\n"); + os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast(ga_i), n_past, ib*bd, static_cast(ga_i + ib*bd), n_past + ib*bd); + os_log_debug(os_log_inst, "div: [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast(ga_i + ib*bd), static_cast(ga_i + ib*bd + ga_w), static_cast(ga_n), static_cast((ga_i + ib*bd)/ga_n), static_cast((ga_i + ib*bd + ga_w)/ga_n)); + os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd); [self.ctx kvCacheSeqAdd:0 p0:ga_i p1:n_past delta:ib*bd]; [self.ctx kvCacheSeqDiv:0 p0:ga_i + ib*bd p1:ga_i + ib*bd + ga_w delta:ga_n]; @@ -559,7 +566,7 @@ static BOOL file_is_empty(NSString *path) { ga_i += ga_w/ga_n; - os_log_debug(OS_LOG_DEFAULT, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast(ga_i)); + os_log_debug(os_log_inst, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast(ga_i)); } } @@ -585,13 +592,13 @@ static BOOL file_is_empty(NSString *path) { } } - for (int i = 0; i < (int) embd.size(); i += params.nBatch) { + for (int i = 0; i < (int) embd.size(); i += _params.nBatch) { int n_eval = (int) embd.size() - i; - if (n_eval > params.nBatch) { - n_eval = params.nBatch; + if (n_eval > _params.nBatch) { + n_eval = _params.nBatch; } - os_log_debug(OS_LOG_DEFAULT, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str()); + os_log_debug(os_log_inst, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str()); if ([self.ctx decode:[[LlamaBatch alloc] initWithBatch:llama_batch_get_one(&embd[i], n_eval)] ]) { @@ -600,10 +607,10 @@ static BOOL file_is_empty(NSString *path) { n_past += n_eval; - os_log_debug(OS_LOG_DEFAULT, "n_past = %d\n", n_past); + os_log_debug(os_log_inst, "n_past = %d\n", n_past); // Display total tokens alongside total time - if (params.nPrint > 0 && n_past % params.nPrint == 0) { - os_log_debug(OS_LOG_DEFAULT, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast([self.ctx nCtx])); + if (_params.nPrint > 0 && n_past % _params.nPrint == 0) { + os_log_debug(os_log_inst, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast([self.ctx nCtx])); } } @@ -617,19 +624,19 @@ static BOOL file_is_empty(NSString *path) { if ((int) embd_inp.size() <= n_consumed && !isInteracting) { // optionally save the session on first sample (for faster prompt loading next time) - if ([pathSession length] > 0 && need_to_save_session && !params.promptCacheRO) { + if ([pathSession length] > 0 && need_to_save_session && !_params.promptCacheRO) { need_to_save_session = false; [self.ctx saveStateFile:pathSession tokens:session_tokens.data() nTokenCount:session_tokens.size()]; // llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); - os_log_debug(OS_LOG_DEFAULT, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_debug(os_log_inst, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]); } - const llama_token idToken = [smpl sample:self.ctx index:-1]; + const llama_token idToken = [_smpl sample:self.ctx index:-1]; - [smpl accept:idToken acceptGrammar:true]; + [_smpl accept:idToken acceptGrammar:true]; - // os_log_debug(OS_LOG_DEFAULT, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str()); + // os_log_debug(os_log_inst, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(idToken); @@ -639,19 +646,19 @@ static BOOL file_is_empty(NSString *path) { // decrement remaining sampling budget --n_remain; - os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain); + os_log_debug(os_log_inst, "n_remain: %d\n", n_remain); } else { // some user input remains from prompt or interaction, forward it to processing - os_log_debug(OS_LOG_DEFAULT, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + os_log_debug(os_log_inst, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); while ((int) embd_inp.size() > n_consumed) { embd.push_back(embd_inp[n_consumed]); // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - [smpl accept:embd_inp[n_consumed] acceptGrammar:false]; + [_smpl accept:embd_inp[n_consumed] acceptGrammar:false]; ++n_consumed; - if ((int) embd.size() >= params.nBatch) { + if ((int) embd.size() >= _params.nBatch) { break; } } @@ -662,10 +669,10 @@ static BOOL file_is_empty(NSString *path) { // std::cout<< "DISPLAYING TEXT" << std::endl; for (auto idToken : embd) { - NSString *token_str = [self.ctx tokenToPiece:idToken special:params.special]; + NSString *token_str = [self.ctx tokenToPiece:idToken special:_params.special]; // Console/Stream Output - os_log_info(OS_LOG_DEFAULT, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]); // Record Displayed Tokens To Log // Note: Generated tokens are created one by one hence this check @@ -678,6 +685,9 @@ static BOOL file_is_empty(NSString *path) { output_tokens.push_back(idToken); output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding]; last_output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding]; + [self willChangeValueForKey:@"lastOutput"]; + [_mutableLastOutput appendString:token_str]; + [self didChangeValueForKey:@"lastOutput"]; } } @@ -698,23 +708,23 @@ static BOOL file_is_empty(NSString *path) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // check for reverse prompt in the last n_prev tokens - if ([params.antiPrompts count] > 0) { + if ([_params.antiPrompts count] > 0) { const int n_prev = 32; - NSString *last_output = [smpl previousString:self.ctx n:n_prev]; + NSString *last_output = [_smpl previousString:self.ctx n:n_prev]; is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. // If we're not running interactively, the reverse prompt might be tokenized with some following characters // so we'll compensate for that by widening the search window a bit. - for (NSString *antiprompt in params.antiPrompts) { - size_t extra_padding = params.interactive ? 0 : 2; + for (NSString *antiprompt in _params.antiPrompts) { + size_t extra_padding = _params.interactive ? 0 : 2; size_t search_start_pos = [last_output length] > static_cast([antiprompt length] + extra_padding) ? [last_output length] - static_cast([antiprompt length] + extra_padding) : 0; // TODO: Check if correct if ([last_output rangeOfString:antiprompt options:0 range:NSMakeRange(search_start_pos, last_output.length - search_start_pos)].location != NSNotFound) { - if (params.interactive) { + if (_params.interactive) { isInteracting = true; } is_antiprompt = true; @@ -723,10 +733,10 @@ static BOOL file_is_empty(NSString *path) { } // check for reverse prompt using special tokens - llama_token last_token = [smpl last]; + llama_token last_token = [_smpl last]; for (std::vector ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { - if (params.interactive) { + if (_params.interactive) { isInteracting = true; } is_antiprompt = true; @@ -735,25 +745,25 @@ static BOOL file_is_empty(NSString *path) { } if (is_antiprompt) { - os_log_debug(OS_LOG_DEFAULT, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_debug(os_log_inst, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]); } } // deal with end of generation tokens in interactive mode - if ([self.model tokenIsEOG:[smpl last]]) { - os_log_debug(OS_LOG_DEFAULT, "found an EOG token\n"); + if ([self.model tokenIsEOG:[_smpl last]]) { + os_log_debug(os_log_inst, "found an EOG token\n"); - if (params.interactive) { - if ([[params antiPrompts] count] > 0) { + if (_params.interactive) { + if ([[_params antiPrompts] count] > 0) { // tokenize and inject first reverse prompt - const auto first_antiprompt = [self.ctx tokenize:params.antiPrompts[0] addSpecial:false parseSpecial:true]; + const auto first_antiprompt = [self.ctx tokenize:_params.antiPrompts[0] addSpecial:false parseSpecial:true]; embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); is_antiprompt = true; } - if (params.enableChatTemplate) { + if (_params.enableChatTemplate) { [self chat_add_and_format:chat_msgs role:"assistant" content:assistant_ss.str()]; @@ -764,32 +774,32 @@ static BOOL file_is_empty(NSString *path) { } // if current token is not EOG, we add it to current assistant message - if (params.conversation) { - const auto idToken = [smpl last]; + if (_params.conversation) { + const auto idToken = [_smpl last]; assistant_ss << [[self.ctx tokenToPiece:idToken special:false] cStringUsingEncoding:NSUTF8StringEncoding]; } if (n_past > 0 && isInteracting) { - os_log_debug(OS_LOG_DEFAULT, "waiting for user input\n"); + os_log_debug(os_log_inst, "waiting for user input\n"); - if (params.conversation) { + if (_params.conversation) { // osLog_("\n> "); } - if (params.inputPrefixBOS) { - os_log_debug(OS_LOG_DEFAULT, "adding input prefix BOS token\n"); + if (_params.inputPrefixBOS) { + os_log_debug(os_log_inst, "adding input prefix BOS token\n"); embd_inp.push_back([self.model tokenBOS]); } std::string buffer; - if ([params.inputPrefix length] > 0 && !params.conversation) { - os_log_debug(OS_LOG_DEFAULT, "appending input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); - os_log_info(OS_LOG_DEFAULT, "%s", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); + if ([_params.inputPrefix length] > 0 && !_params.conversation) { + os_log_debug(os_log_inst, "appending input prefix: '%s'\n", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "%s", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]); } // color user input only // console::set_display(console::user_input); - display = params.displayPrompt; + display = _params.displayPrompt; std::string line; // bool another_line = true; @@ -806,8 +816,12 @@ static BOOL file_is_empty(NSString *path) { auto str = last_output_ss.str(); last_output_ss.str(""); [queue addOutputLine:[NSString stringWithCString:str.c_str() encoding:NSUTF8StringEncoding]]; + [self willChangeValueForKey:@"lastOutput"]; + _mutableLastOutput = [[NSMutableString alloc] init]; + [self didChangeValueForKey:@"lastOutput"]; + } - + buffer = [[queue inputLine] cStringUsingEncoding:NSUTF8StringEncoding]; // do { // another_line = console::readline(line, params.multiline_input); @@ -822,34 +836,34 @@ static BOOL file_is_empty(NSString *path) { // Entering a empty line lets the user pass control back if (buffer.length() > 1) { // append input suffix if any - if ([params.inputSuffix length] > 0 && !params.conversation) { - os_log_debug(OS_LOG_DEFAULT, "appending input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); - os_log_info(OS_LOG_DEFAULT, "%s", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); + if ([[self params].inputSuffix length] > 0 && !_params.conversation) { + os_log_debug(os_log_inst, "appending input suffix: '%s'\n", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); + os_log_info(os_log_inst, "%s", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]); } - os_log_debug(OS_LOG_DEFAULT, "buffer: '%s'\n", buffer.c_str()); + os_log_debug(os_log_inst, "buffer: '%s'\n", buffer.c_str()); const size_t original_size = embd_inp.size(); - if (params.escapeSequences) { + if (_params.escapeSequences) { string_process_escapes(buffer); } - bool format_chat = params.conversation && params.enableChatTemplate; + bool format_chat = _params.conversation && _params.enableChatTemplate; std::string user_inp = format_chat ? [[self chat_add_and_format:chat_msgs role:"user" content:std::move(buffer)] cStringUsingEncoding:NSUTF8StringEncoding] : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) - const auto line_pfx = [self.ctx tokenize:params.inputPrefix addSpecial:false parseSpecial:true]; + const auto line_pfx = [self.ctx tokenize:_params.inputPrefix addSpecial:false parseSpecial:true]; const auto line_inp = [self.ctx tokenize:[NSString stringWithCString:user_inp.c_str() encoding:NSUTF8StringEncoding] addSpecial:false parseSpecial:format_chat]; - const auto line_sfx = [self.ctx tokenize:params.inputSuffix + const auto line_sfx = [self.ctx tokenize:_params.inputSuffix addSpecial:false parseSpecial:true]; - os_log_debug(OS_LOG_DEFAULT, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str()); + os_log_debug(os_log_inst, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str()); // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { @@ -872,9 +886,9 @@ static BOOL file_is_empty(NSString *path) { assistant_ss.str(""); n_remain -= line_inp.size(); - os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain); + os_log_debug(os_log_inst, "n_remain: %d\n", n_remain); } else { - os_log_debug(OS_LOG_DEFAULT, "empty line, passing control back\n"); + os_log_debug(os_log_inst, "empty line, passing control back\n"); } input_echo = false; // do not echo this again @@ -882,22 +896,22 @@ static BOOL file_is_empty(NSString *path) { if (n_past > 0) { if (isInteracting) { - [smpl reset]; + [_smpl reset]; } isInteracting = false; } } // end of generation - if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(params.interactive)) { - os_log_info(OS_LOG_DEFAULT, " [end of text]\n"); + if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(_params.interactive)) { + os_log_info(os_log_inst, " [end of text]\n"); break; } // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size). - if (params.interactive && n_remain <= 0 && params.nPredict >= 0) { - n_remain = params.nPredict; + if (_params.interactive && n_remain <= 0 && _params.nPredict >= 0) { + n_remain = _params.nPredict; isInteracting = true; } } diff --git a/objc/include/CPUParams.h b/objc/include/CPUParams.h new file mode 100644 index 000000000..e9925c134 --- /dev/null +++ b/objc/include/CPUParams.h @@ -0,0 +1,30 @@ +#ifndef CPUParams_h +#define CPUParams_h + +typedef NS_ENUM(NSUInteger, GGMLSchedPriority); + +@class GGMLThreadpoolParams; + +@interface CPUParams : NSObject + +/// Number of threads to use +@property (nonatomic, assign) NSInteger nThreads; + +/// Default: any CPU +@property (nonatomic, assign) BOOL maskValid; +/// Scheduling priority +@property (nonatomic, assign) GGMLSchedPriority priority; +/// Use strict CPU placement +@property (nonatomic, assign) BOOL strictCPU; +/// Polling (busywait) level (0 - no polling, 100 - mostly polling) +@property (nonatomic, assign) NSUInteger poll; + +// Custom methods to access or manipulate the cpumask array +- (BOOL)getCpuMaskAtIndex:(NSUInteger)index; +- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index; + +- (GGMLThreadpoolParams *)ggmlThreadpoolParams; + +@end + +#endif /* Header_h */ diff --git a/objc/include/CPUParams_Private.hpp b/objc/include/CPUParams_Private.hpp new file mode 100644 index 000000000..c8248f1a0 --- /dev/null +++ b/objc/include/CPUParams_Private.hpp @@ -0,0 +1,15 @@ +#ifndef CPUParams_Private_hpp +#define CPUParams_Private_hpp + +#import "CPUParams.h" +#import "../../common/common.h" + +@interface CPUParams() { + cpu_params *params; +} + +- (instancetype)initWithParams:(cpu_params&)params; + +@end + +#endif /* CPUParams_Private_hpp */ diff --git a/objc/include/GGMLThreadpool.h b/objc/include/GGMLThreadpool.h new file mode 100644 index 000000000..4dd46c250 --- /dev/null +++ b/objc/include/GGMLThreadpool.h @@ -0,0 +1,29 @@ +#ifndef GGMLThreadpool_h +#define GGMLThreadpool_h + +typedef NS_ENUM(NSUInteger, GGMLSchedPriority) { + GGMLSchedPriorityNormal = 0, // Normal priority + GGMLSchedPriorityMedium = 1, // Medium priority + GGMLSchedPriorityHigh = 2, // High priority + GGMLSchedPriorityRealtime = 3 // Realtime priority +}; + + +@interface GGMLThreadpool : NSObject +@end + +@interface GGMLThreadpoolParams : NSObject + +@property (nonatomic, assign) int nThreads; +@property (nonatomic, assign) GGMLSchedPriority priority; +@property (nonatomic, assign) uint32_t poll; +@property (nonatomic, assign) BOOL strictCPU; +@property (nonatomic, assign) BOOL paused; + +- (BOOL)getCpuMaskAtIndex:(NSUInteger)index; +- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index; +- (GGMLThreadpool *)threadpool; + +@end + +#endif /* GGMLThreadpool_h */ diff --git a/objc/include/GGMLThreadpool_Private.hpp b/objc/include/GGMLThreadpool_Private.hpp new file mode 100644 index 000000000..66f459e84 --- /dev/null +++ b/objc/include/GGMLThreadpool_Private.hpp @@ -0,0 +1,22 @@ +#ifndef GGMLThreadpool_Private_hpp +#define GGMLThreadpool_Private_hpp + +#import "GGMLThreadpool.h" +#import "../../common/common.h" + +@interface GGMLThreadpool() { + ggml_threadpool *threadpool; +} + +- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool; +- (ggml_threadpool *)threadpool; + +@end + +@interface GGMLThreadpoolParams() + +- (instancetype)initWithParams:(ggml_threadpool_params&&)params; + +@end + +#endif /* GGMLThreadpool_Private_hpp */ diff --git a/objc/include/GPTParams.h b/objc/include/GPTParams.h index 3fe19b1c4..b20449c34 100644 --- a/objc/include/GPTParams.h +++ b/objc/include/GPTParams.h @@ -4,48 +4,7 @@ @class LlamaModelParams; @class LlamaContextParams; @class GGMLThreadpool; - -// Define the ggml_sched_priority enum -typedef NS_ENUM(NSInteger, GGMLSchedPriority) { - GGMLSchedPriorityNormal = 0, // Normal priority - GGMLSchedPriorityMedium = 1, // Medium priority - GGMLSchedPriorityHigh = 2, // High priority - GGMLSchedPriorityRealtime = 3 // Realtime priority -}; - -@interface GGMLThreadpoolParams : NSObject - -@property (nonatomic, assign) int nThreads; -@property (nonatomic, assign) GGMLSchedPriority priority; -@property (nonatomic, assign) uint32_t poll; -@property (nonatomic, assign) BOOL strictCPU; -@property (nonatomic, assign) BOOL paused; - -// Custom access methods for the cpumask array -- (BOOL)getCpuMaskAtIndex:(NSUInteger)index; -- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index; -- (GGMLThreadpool *)threadpool; - -@end - -@interface GGMLThreadpool : NSObject -@end - -@interface CPUParams : NSObject - -// Properties -@property (nonatomic, assign) int nThreads; -@property (nonatomic, assign) BOOL maskValid; -@property (nonatomic, assign) GGMLSchedPriority priority; -@property (nonatomic, assign) BOOL strictCPU; -@property (nonatomic, assign) uint32_t poll; - -// Custom methods to access or manipulate the cpumask array -- (BOOL)getCpuMaskAtIndex:(NSUInteger)index; -- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index; -- (GGMLThreadpoolParams *)ggmlThreadpoolParams; - -@end +@class CPUParams; @interface GPTSamplerParams : NSObject @@ -72,13 +31,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) { @property (nonatomic, assign) BOOL penalizeNl; @property (nonatomic, assign) BOOL ignoreEos; @property (nonatomic, assign) BOOL noPerf; +@property (nonatomic, strong) NSArray *samplers; +@property (nonatomic, copy) NSString *grammar; +@property (nonatomic, strong) NSArray *logitBias; -// Arrays and Strings -@property (nonatomic, strong) NSArray *samplers; // Samplers mapped to NSArray of NSNumber (for enums) -@property (nonatomic, copy) NSString *grammar; // Grammar as NSString -@property (nonatomic, strong) NSArray *logitBias; // Logit biases mapped to NSArray of NSNumber - -// Method to print the parameters into a string - (NSString *)print; @end @@ -98,7 +54,7 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) { @property (nonatomic, assign) int32_t nGpuLayers; @property (nonatomic, assign) int32_t nGpuLayersDraft; @property (nonatomic, assign) int32_t mainGpu; -@property (nonatomic, strong) NSMutableArray *tensorSplit; // Fixed-size array, stays the same +@property (nonatomic, strong) NSArray *tensorSplit; @property (nonatomic, assign) int32_t grpAttnN; @property (nonatomic, assign) int32_t grpAttnW; @property (nonatomic, assign) int32_t nPrint; @@ -111,13 +67,11 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) { @property (nonatomic, assign) int32_t yarnOrigCtx; @property (nonatomic, assign) float defragThold; -// You need to replace your C++ struct "cpu_params" with an Objective-C class or struct accordingly @property (nonatomic, strong) CPUParams *cpuParams; @property (nonatomic, strong) CPUParams *cpuParamsBatch; @property (nonatomic, strong) CPUParams *draftCpuParams; @property (nonatomic, strong) CPUParams *draftCpuParamsBatch; -// Callbacks (assuming they are blocks in Objective-C) @property (nonatomic, copy) void (^cbEval)(void *); @property (nonatomic, assign) void *cbEvalUserData; @@ -149,12 +103,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) { @property (nonatomic, copy) NSString *logitsFile; @property (nonatomic, copy) NSString *rpcServers; -// Arrays in Objective-C are represented with `NSArray` @property (nonatomic, strong) NSArray *inputFiles; @property (nonatomic, strong) NSArray *antiPrompts; @property (nonatomic, strong) NSArray *kvOverrides; -// Boolean values (in Objective-C, use `BOOL`) @property (nonatomic, assign) BOOL loraInitWithoutApply; @property (nonatomic, strong) NSArray *loraAdapters; @property (nonatomic, strong) NSArray *controlVectors; @@ -257,6 +209,8 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) { @property (nonatomic, assign) BOOL ctxShift; // context shift on inifinite text generation @property (nonatomic, assign) BOOL displayPrompt; // print prompt before generation +@property (nonatomic, assign) BOOL logging; // print logging + @end #endif /* GPTParams_h */ diff --git a/objc/include/GPTParams_Private.hpp b/objc/include/GPTParams_Private.hpp index 3285d6641..1c25ccb81 100644 --- a/objc/include/GPTParams_Private.hpp +++ b/objc/include/GPTParams_Private.hpp @@ -5,11 +5,6 @@ #import "ggml.h" #import "../../common/common.h" -@interface GGMLThreadpool() - -- (ggml_threadpool *)threadpool; - -@end @interface GPTParams() diff --git a/objc/include/GPTSampler.h b/objc/include/GPTSampler.h index 317ae6cda..9fe29ee8f 100644 --- a/objc/include/GPTSampler.h +++ b/objc/include/GPTSampler.h @@ -38,14 +38,14 @@ typedef int32_t LlamaToken; index:(NSInteger) index grammarFirst:(BOOL)grammarFirst; -// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar +/// If accept_grammar is true, the token is accepted both by the sampling chain and the grammar - (void)accept:(LlamaToken)token acceptGrammar:(BOOL)acceptGrammar; -// get a string representation of the last accepted tokens +/// Get a string representation of the last accepted tokens - (NSString *)previousString:(LlamaContext *)context n:(NSInteger)n; -// get the last accepted token +/// Get the last accepted token - (LlamaToken)last; - (void)reset; diff --git a/objc/include/LlamaBatch.h b/objc/include/LlamaBatch.h index f5354ba1e..6af5627be 100644 --- a/objc/include/LlamaBatch.h +++ b/objc/include/LlamaBatch.h @@ -5,15 +5,15 @@ typedef NSInteger LlamaSequenceId; typedef NSInteger LlamaPosition; typedef int32_t LlamaToken; -// Input data for llama_decode -// A llama_batch object can contain input about one or many sequences -// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens -// -// - token : the token ids of the input (used when embd is NULL) -// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) -// - pos : the positions of the respective token in the sequence -// - seq_id : the sequence to which the respective token belongs -// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +/// Input data for llama_decode +/// A llama_batch object can contain input about one or many sequences +/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +/// +/// - token : the token ids of the input (used when embd is NULL) +/// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) +/// - pos : the positions of the respective token in the sequence +/// - seq_id : the sequence to which the respective token belongs +/// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output @interface LlamaBatch : NSObject @property (nonatomic, assign) NSInteger nTokens; diff --git a/objc/include/LlamaObjC.h b/objc/include/LlamaObjC.h index d7fb1b139..0573133bb 100644 --- a/objc/include/LlamaObjC.h +++ b/objc/include/LlamaObjC.h @@ -2,6 +2,8 @@ #define LlamaObjC_h #include +#include +#include #include #include #include diff --git a/objc/include/LlamaSession.h b/objc/include/LlamaSession.h index 45d2c5eea..2f7005d50 100644 --- a/objc/include/LlamaSession.h +++ b/objc/include/LlamaSession.h @@ -14,10 +14,11 @@ @end -@interface LlamaSession : NSObject +NS_REFINED_FOR_SWIFT @interface LlamaSession : NSObject @property (nonatomic, strong) LlamaModel *model; @property (nonatomic, strong) LlamaContext *ctx; +@property (nonatomic, strong, readonly) NSString *lastOutput; - (instancetype)initWithParams:(GPTParams *)params; - (void)start:(BlockingLineQueue *)queue; diff --git a/objc/include/LlamaSession_Private.hpp b/objc/include/LlamaSession_Private.hpp index 7e3b0243f..e95e6edba 100644 --- a/objc/include/LlamaSession_Private.hpp +++ b/objc/include/LlamaSession_Private.hpp @@ -3,8 +3,14 @@ #import "LlamaSession.h" +@class GPTSampler; + @interface LlamaSession() +@property (atomic, strong) NSMutableString *mutableLastOutput; +@property (nonatomic, strong) GPTParams *params; +@property (nonatomic, strong) GPTSampler *smpl; + @end #endif /* LlamaSession_Private_hpp */ diff --git a/swift/JSONSchema/Grammar.swift b/swift/JSONSchema/Grammar.swift index daa2e3dbe..3dd5273e5 100644 --- a/swift/JSONSchema/Grammar.swift +++ b/swift/JSONSchema/Grammar.swift @@ -24,8 +24,6 @@ public class SchemaConverter { } private func formatLiteral(_ literal: Any) -> String { -// let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)", { -// let regex = Regex("[\r\n\"]") let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)") { $0.replacingOccurrences(of: $1.key, with: $1.value) } diff --git a/swift/JSONSchema/JSONSchema.swift b/swift/JSONSchema/JSONSchema.swift index 0ed76e65a..b3b3269c2 100644 --- a/swift/JSONSchema/JSONSchema.swift +++ b/swift/JSONSchema/JSONSchema.swift @@ -57,6 +57,28 @@ public struct _JSONFunctionSchema: Codable { self.enum = nil } + public init(type: Int.Type, description: String?) { + self.type = "integer" + self.description = description + self.items = nil + self.enum = nil + } + + + public init(type: Double.Type, description: String?) { + self.type = "number" + self.description = description + self.items = nil + self.enum = nil + } + + public init(type: Bool.Type, description: String?) { + self.type = "boolean" + self.description = description + self.items = nil + self.enum = nil + } + public init(type: T.Type, description: String?) where T: RawRepresentable, T: StringProtocol { self.type = "string" @@ -146,6 +168,14 @@ extension Double : JSONSchemaConvertible { ] } } +extension Bool : JSONSchemaConvertible { + public static var type: String { "boolean" } + public static var jsonSchema: [String: Any] { + [ + "type": "boolean" + ] + } +} extension Date : JSONSchemaConvertible { public static var type: String { "string" } diff --git a/swift/LlamaKit/LlamaKit.swift b/swift/LlamaKit/LlamaKit.swift index 8f4bd9f89..784edb19a 100644 --- a/swift/LlamaKit/LlamaKit.swift +++ b/swift/LlamaKit/LlamaKit.swift @@ -7,11 +7,63 @@ public protocol DynamicCallable: Sendable { func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String } +public enum AnyDecodable: Decodable { + case string(String) + case int(Int) + case double(Double) + case bool(Bool) + case null + // Add other cases as needed + + // Initializers for each type + init(_ value: String) { + self = .string(value) + } + + init(_ value: Int) { + self = .int(value) + } + + init(_ value: Double) { + self = .double(value) + } + + init(_ value: Bool) { + self = .bool(value) + } + + init() { + self = .null + } + + // Decodable conformance + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + + if container.decodeNil() { + self = .null + } else if let intValue = try? container.decode(Int.self) { + self = .int(intValue) + } else if let doubleValue = try? container.decode(Double.self) { + self = .double(doubleValue) + } else if let boolValue = try? container.decode(Bool.self) { + self = .bool(boolValue) + } else if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else { + let context = DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "Cannot decode AnyDecodable" + ) + throw DecodingError.typeMismatch(AnyDecodable.self, context) + } + } +} struct ToolCall: Decodable { let id: Int let name: String - let arguments: [String: String] + let arguments: [String: AnyDecodable] } struct ToolResponse: Encodable { @@ -23,10 +75,13 @@ struct ToolResponse: Encodable { /// Standard chat session for a given LLM. public actor LlamaChatSession { private let queue = BlockingLineQueue() - private let session: LlamaObjC.LlamaSession + private let session: __LlamaSession + /// Initialize the session + /// - parameter params: common parameters to initialize the session + /// - parameter flush: whether or not to flush the initial prompt, reading initial output public init(params: GPTParams, flush: Bool = true) async throws { - session = LlamaObjC.LlamaSession(params: params) + self.session = __LlamaSession(params: params) Task.detached { [session, queue] in session.start(queue) } @@ -36,7 +91,36 @@ public actor LlamaChatSession { _ = queue.outputLine() } - public func chat(message: String) async -> String { + /// Create a new inference stream for a given message + /// - parameter message: The message to receive an inference for. + /// - returns: A stream of output from the LLM. + public func inferenceStream(message: String) async -> AsyncStream { + queue.addInputLine(message) + var observationToken: NSKeyValueObservation? + return AsyncStream { stream in + observationToken = self.session.observe(\.lastOutput, options: [.new, .old]) { session, change in + guard let newValue = change.newValue, + let oldValue = change.oldValue else { + return stream.finish() + } + var delta = "" + for change in newValue!.difference(from: oldValue!) { + switch change { + case .remove(_, _, _): + return stream.finish() + case .insert(_, let element, _): + delta.append(element) + } + } + stream.yield(delta) + } + stream.onTermination = { [observationToken] _ in + observationToken?.invalidate() + } + } + } + + public func infer(message: String) async -> String { queue.addInputLine(message) return queue.outputLine() } @@ -46,15 +130,15 @@ public actor LlamaChatSession { public actor LlamaSession { private let session: LlamaChatSession - public init(params: GPTParams) async throws { + public init(params: GPTParams, flush: Bool = true) async throws { let converter = SchemaConverter(propOrder: []) _ = converter.visit(schema: T.jsonSchema, name: nil) params.samplerParams.grammar = converter.formatGrammar() - session = try await LlamaChatSession(params: params) + session = try await LlamaChatSession(params: params, flush: flush) } public func chat(message: String) async throws -> T { - let output = await session.chat(message: message).data(using: .utf8)! + let output = await session.infer(message: message).data(using: .utf8)! return try JSONDecoder().decode(T.self, from: output) } } @@ -117,11 +201,14 @@ public actor LlamaToolSession { self.tools["getIpAddress"] = (GetIpAddress(), ipFnSchema) let encoded = try JSONEncoder().encode(self.tools.values.map(\.1)) let prompt = """ + \(params.prompt ?? "") + You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within XML tags as follows: {"name": ,"arguments": } + Feel free to chain tool calls, e.g., if you need the user's location to find points of interest near them, fetch the user's location first. The first call you will be asked to warm up is to get the user's IP address. Here are the available tools: \(String(data: encoded, encoding: .utf8)!) <|eot_id|> """ @@ -131,7 +218,7 @@ public actor LlamaToolSession { params.inputPrefix = "<|start_header_id|>user<|end_header_id|>"; params.inputSuffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; session = try await LlamaChatSession(params: params, flush: false) - let fn = await session.chat(message: "What is my IP address?") + let fn = await session.infer(message: "What is my IP address?") let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!) guard let tool = self.tools[toolCall.name] else { fatalError() @@ -139,7 +226,7 @@ public actor LlamaToolSession { let resp = try await tool.0.dynamicallyCall(withKeywordArguments: toolCall.arguments) print(resp) - let output = await session.chat(message: """ + let output = await session.infer(message: """ {"id": \(toolCall.id), result: \(resp)} @@ -147,39 +234,58 @@ public actor LlamaToolSession { print(output) } - public func chat(message: String) async throws -> String { - var nxt = await session.chat(message: message) - let fn = nxt - // try to see if the output is a function call + private func callTool(_ call: String) async -> String? { + var nxt: String? do { - let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!) + let toolCall = try JSONDecoder().decode(ToolCall.self, from: call.data(using: .utf8)!) guard let tool = tools[toolCall.name] else { fatalError() } + // TODO: tool call decode is allowed to fail but the code below is not let callable = tool.0 - let resp = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments) - print("tool response: \(resp)") - nxt = await session.chat(message: """ - - {"id": \(toolCall.id), result: \(resp)} - - """) - print(nxt) - } catch { - print(error) - } + + do { + let response = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments) + print("tool response: \(response)") + nxt = await session.infer(message: """ + + {"id": \(toolCall.id), result: \(response)} + + """) + // TODO: If this decodes correctly, we should tail this into this method + // TODO: so that we do not decode twice + if let _ = try? JSONDecoder().decode(ToolCall.self, from: nxt!.data(using: .utf8)!) { + return await callTool(nxt!) + } + } catch { + nxt = await session.infer(message: """ + + {"id": \(toolCall.id), result: "The tool call has unfortunately failed."} + + """) + } + print(nxt ?? "nil") + } catch {} return nxt } + + public func infer(message: String) async throws -> String { + let output = await session.infer(message: message) + guard let output = await callTool(output) else { + return output + } + return output + } } public protocol LlamaActor: Actor { - static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] { get } - var session: LlamaToolSession { get } + static func tools(_ self: Self) -> [String: (DynamicCallable, _JSONFunctionSchema)] + var session: LlamaToolSession! { get } } public extension LlamaActor { func chat(_ message: String) async throws -> String { - try await session.chat(message: message) + try await session.infer(message: message) } } diff --git a/swift/LlamaKitMacros/LlamaKitMacros.swift b/swift/LlamaKitMacros/LlamaKitMacros.swift index 0e67808d2..9a3a1aadb 100644 --- a/swift/LlamaKitMacros/LlamaKitMacros.swift +++ b/swift/LlamaKitMacros/LlamaKitMacros.swift @@ -16,12 +16,14 @@ struct ToolMacro: BodyMacro { struct LlamaActorMacro: ExtensionMacro, MemberMacro { static func expansion(of node: AttributeSyntax, providingMembersOf declaration: some DeclGroupSyntax, conformingTo protocols: [TypeSyntax], in context: some MacroExpansionContext) throws -> [DeclSyntax] { - [ + return [ """ - let session: LlamaToolSession + var session: LlamaToolSession! public init(params: GPTParams) async throws { - self.session = try await LlamaToolSession(params: params, tools: Self.tools) + \(raw: declaration.inheritanceClause != nil ? "self.init()" : "") + let tools = Self.tools(self) + self.session = try await LlamaToolSession(params: params, tools: tools) } """ ] @@ -41,6 +43,7 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro { callableString: String, callableName: String) ] = [] + let typeName = type.as(IdentifierTypeSyntax.self)!.name.text for member in declaration.memberBlock.members { let comments = member.leadingTrivia.filter { $0.isComment } guard let member = member.decl.as(FunctionDeclSyntax.self) else { @@ -72,6 +75,11 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro { let callableName = context.makeUniqueName(name.text) let callableString = """ @dynamicCallable struct \(callableName.text): DynamicCallable { + private weak var llamaActor: \(typeName)? + init(_ llamaActor: \(typeName)) { + self.llamaActor = llamaActor + } + @discardableResult func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String { \(parameters.map { @@ -79,11 +87,15 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro { }.joined(separator: "\n")) for (key, value) in args { \(parameters.map { - "if key == \"\($0.name)\" { \($0.name) = value as! \($0.type) }" + """ + if key == "\($0.name)", let value = value as? AnyDecodable, case let .\($0.type.lowercased())(v) = value { + \($0.name) = v + } + """ }.joined(separator: "\n")) } - let returnValue = try await \(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ","))) + let returnValue = try await self.llamaActor!.\(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ","))) let jsonValue = try JSONEncoder().encode(returnValue) return String(data: jsonValue, encoding: .utf8)! } @@ -105,10 +117,10 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro { $0.callableString }.joined(separator: "\n")) - static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] { + static func tools(_ self: \(raw: typeName)) -> [String: (DynamicCallable, _JSONFunctionSchema)] { [\(raw: tools.map { tool in """ - "\(tool.name)": (\(tool.callableName)(), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in + "\(tool.name)": (\(tool.callableName)(self), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in """ "\(parameter.name)": _JSONFunctionSchema.Property(type: \(parameter.type).self, description: "\(parameter.description)"), """ diff --git a/swift/main/main.swift b/swift/main/main.swift index 716592262..745fc4b0d 100644 --- a/swift/main/main.swift +++ b/swift/main/main.swift @@ -32,7 +32,7 @@ func downloadFile() async throws -> String { @llamaActor actor MyLlama { /// Get the current date. - @Tool public static func getCurrentDate() -> String { + @Tool public func getCurrentDate() -> String { Date.now.formatted(date: .long, time: .complete) } } diff --git a/swift/test/GPTParamsTests.swift b/swift/test/GPTParamsTests.swift new file mode 100644 index 000000000..2e6587dcc --- /dev/null +++ b/swift/test/GPTParamsTests.swift @@ -0,0 +1,370 @@ +import Foundation +import XCTest + +import LlamaKit // Replace with your module name + +final class GPTParamsTests: XCTestCase { + func testPropertyAssignmentsAndCopy() throws { + // Create an instance of GPTParams + let originalParams = GPTParams() + + // Assign values to all properties + originalParams.nPredict = 10 + originalParams.nCtx = 20 + originalParams.nBatch = 30 + originalParams.nUBatch = 40 + originalParams.nKeep = 50 + originalParams.nDraft = 60 + originalParams.nChunks = 70 + originalParams.nParallel = 80 + originalParams.nSequences = 90 + originalParams.pSplit = 0.5 + originalParams.nGpuLayers = 100 + originalParams.nGpuLayersDraft = 110 + originalParams.mainGpu = 120 + originalParams.tensorSplit = [0.1, 0.2, 0.3] + originalParams.grpAttnN = 130 + originalParams.grpAttnW = 140 + originalParams.nPrint = 150 + originalParams.ropeFreqBase = 0.6 + originalParams.ropeFreqScale = 0.7 + originalParams.yarnExtFactor = 0.8 + originalParams.yarnAttnFactor = 0.9 + originalParams.yarnBetaFast = 1.0 + originalParams.yarnBetaSlow = 1.1 + originalParams.yarnOrigCtx = 160 + originalParams.defragThold = 1.2 + + // Initialize CPUParams instances if needed + originalParams.cpuParams = CPUParams() + originalParams.cpuParamsBatch = CPUParams() + originalParams.draftCpuParams = CPUParams() + originalParams.draftCpuParamsBatch = CPUParams() + + // Assign blocks and user data + originalParams.cbEval = { userData in + // Callback implementation + } + originalParams.cbEvalUserData = nil + + // Assign enum values (assuming NSInteger maps to Int) + originalParams.numaStrategy = 1 + originalParams.splitMode = 2 + originalParams.ropeScalingType = 3 + originalParams.poolingType = 4 + originalParams.attentionType = 5 + + // Assign sampler parameters + originalParams.samplerParams = GPTSamplerParams() + + // Assign string properties + originalParams.modelPath = "path/to/model" + originalParams.modelDraft = "model_draft" + originalParams.modelAlias = "alias" + originalParams.modelURL = "http://model.url" + originalParams.hfToken = "token" + originalParams.hfRepo = "repo" + originalParams.hfFile = "file" + originalParams.prompt = "prompt" + originalParams.promptFile = "prompt.txt" + originalParams.pathPromptCache = "cache/path" + originalParams.inputPrefix = "prefix" + originalParams.inputSuffix = "suffix" + originalParams.logdir = "log/dir" + originalParams.lookupCacheStatic = "static/cache" + originalParams.lookupCacheDynamic = "dynamic/cache" + originalParams.logitsFile = "logits.txt" + originalParams.rpcServers = "servers" + + // Assign array properties + originalParams.inputFiles = ["input1.txt", "input2.txt"] + originalParams.antiPrompts = ["anti1", "anti2"] + originalParams.kvOverrides = ["override1", "override2"] + originalParams.loraAdapters = ["adapter1", "adapter2"] + originalParams.controlVectors = ["control1", "control2"] + + // Assign boolean and control properties + originalParams.loraInitWithoutApply = true + originalParams.verbosity = 1 + originalParams.controlVectorLayerStart = 2 + originalParams.controlVectorLayerEnd = 3 + originalParams.pplStride = 4 + originalParams.pplOutputType = 5 + originalParams.hellaswag = true + originalParams.hellaswagTasks = 10 + originalParams.winogrande = false + originalParams.winograndeTasks = 20 + originalParams.multipleChoice = true + originalParams.multipleChoiceTasks = 30 + originalParams.klDivergence = false + originalParams.usage = true + originalParams.useColor = false + originalParams.special = true + originalParams.interactive = false + originalParams.interactiveFirst = true + originalParams.conversation = false + originalParams.promptCacheAll = true + originalParams.promptCacheRO = false + originalParams.escapeSequences = true + originalParams.multilineInput = false + originalParams.simpleIO = true + originalParams.continuousBatching = false + originalParams.flashAttention = true + originalParams.noPerformanceMetrics = false + originalParams.contextShift = true + + // Server and I/O settings + originalParams.port = 8080 + originalParams.timeoutRead = 60 + originalParams.timeoutWrite = 30 + originalParams.httpThreads = 4 + originalParams.hostname = "localhost" + originalParams.publicPath = "/public" + originalParams.chatTemplate = "template" + originalParams.systemPrompt = "system prompt" + originalParams.enableChatTemplate = true + originalParams.apiKeys = ["key1", "key2"] + originalParams.sslFileKey = "key.pem" + originalParams.sslFileCert = "cert.pem" + originalParams.endpointSlots = true + originalParams.endpointMetrics = false + originalParams.logJSON = true + originalParams.slotSavePath = "/slots" + originalParams.slotPromptSimilarity = 0.75 + + // Batched-bench params + originalParams.isPPShared = true + originalParams.nPP = [1, 2] + originalParams.nTG = [3, 4] + originalParams.nPL = [5, 6] + + // Retrieval params + originalParams.contextFiles = ["context1.txt", "context2.txt"] + originalParams.chunkSize = 1024 + originalParams.chunkSeparator = "\n" + + // Passkey params + originalParams.nJunk = 7 + originalParams.iPos = 8 + + // Imatrix params + originalParams.outFile = "output.txt" + originalParams.nOutFreq = 100 + originalParams.nSaveFreq = 200 + originalParams.iChunk = 9 + originalParams.processOutput = true + originalParams.computePPL = false + + // Cvector-generator params + originalParams.nPCABatch = 10 + originalParams.nPCAIterations = 11 + originalParams.cvectorDimreMethod = 12 + originalParams.cvectorOutfile = "cvector.out" + originalParams.cvectorPositiveFile = "positive.txt" + originalParams.cvectorNegativeFile = "negative.txt" + + // Additional properties + originalParams.spmInfill = true + originalParams.loraOutfile = "lora.out" + originalParams.embedding = false + originalParams.verbosePrompt = true + originalParams.batchedBenchOutputJSONL = false + originalParams.inputPrefixBOS = true + originalParams.ctxShift = false + originalParams.displayPrompt = true + originalParams.logging = false + + // Verify that properties are assigned correctly + XCTAssertEqual(originalParams.nPredict, 10) + XCTAssertEqual(originalParams.nCtx, 20) + XCTAssertEqual(originalParams.nBatch, 30) + XCTAssertEqual(originalParams.nUBatch, 40) + XCTAssertEqual(originalParams.nKeep, 50) + XCTAssertEqual(originalParams.nDraft, 60) + XCTAssertEqual(originalParams.nChunks, 70) + XCTAssertEqual(originalParams.nParallel, 80) + XCTAssertEqual(originalParams.nSequences, 90) + XCTAssertEqual(originalParams.pSplit, 0.5) + XCTAssertEqual(originalParams.nGpuLayers, 100) + XCTAssertEqual(originalParams.nGpuLayersDraft, 110) + XCTAssertEqual(originalParams.mainGpu, 120) + XCTAssertEqual(originalParams.tensorSplit[0..<3].map(\.floatValue), + [0.1, 0.2, 0.3]) + XCTAssertEqual(originalParams.grpAttnN, 130) + XCTAssertEqual(originalParams.grpAttnW, 140) + XCTAssertEqual(originalParams.nPrint, 150) + XCTAssertEqual(originalParams.ropeFreqBase, 0.6) + XCTAssertEqual(originalParams.ropeFreqScale, 0.7) + XCTAssertEqual(originalParams.yarnExtFactor, 0.8) + XCTAssertEqual(originalParams.yarnAttnFactor, 0.9) + XCTAssertEqual(originalParams.yarnBetaFast, 1.0) + XCTAssertEqual(originalParams.yarnBetaSlow, 1.1) + XCTAssertEqual(originalParams.yarnOrigCtx, 160) + XCTAssertEqual(originalParams.defragThold, 1.2) + + // Verify enums + XCTAssertEqual(originalParams.numaStrategy, 1) + XCTAssertEqual(originalParams.splitMode, 2) + XCTAssertEqual(originalParams.ropeScalingType, 3) + XCTAssertEqual(originalParams.poolingType, 4) + XCTAssertEqual(originalParams.attentionType, 5) + + // Verify string properties + XCTAssertEqual(originalParams.modelPath, "path/to/model") + XCTAssertEqual(originalParams.modelDraft, "model_draft") + XCTAssertEqual(originalParams.modelAlias, "alias") + XCTAssertEqual(originalParams.modelURL, "http://model.url") + XCTAssertEqual(originalParams.hfToken, "token") + XCTAssertEqual(originalParams.hfRepo, "repo") + XCTAssertEqual(originalParams.hfFile, "file") + XCTAssertEqual(originalParams.prompt, "prompt") + XCTAssertEqual(originalParams.promptFile, "prompt.txt") + XCTAssertEqual(originalParams.pathPromptCache, "cache/path") + XCTAssertEqual(originalParams.inputPrefix, "prefix") + XCTAssertEqual(originalParams.inputSuffix, "suffix") + XCTAssertEqual(originalParams.logdir, "log/dir") + XCTAssertEqual(originalParams.lookupCacheStatic, "static/cache") + XCTAssertEqual(originalParams.lookupCacheDynamic, "dynamic/cache") + XCTAssertEqual(originalParams.logitsFile, "logits.txt") + XCTAssertEqual(originalParams.rpcServers, "servers") + + // Verify array properties + XCTAssertEqual(originalParams.inputFiles, ["input1.txt", "input2.txt"]) + XCTAssertEqual(originalParams.antiPrompts, ["anti1", "anti2"]) + XCTAssertEqual(originalParams.kvOverrides as? [String], ["override1", "override2"]) +// XCTAssertEqual(originalParams.loraAdapters, ["adapter1", "adapter2"]) +// XCTAssertEqual(originalParams.controlVectors, ["control1", "control2"]) + + // Verify boolean and control properties + XCTAssertTrue(originalParams.loraInitWithoutApply) + XCTAssertEqual(originalParams.verbosity, 1) + XCTAssertEqual(originalParams.controlVectorLayerStart, 2) + XCTAssertEqual(originalParams.controlVectorLayerEnd, 3) + XCTAssertEqual(originalParams.pplStride, 4) + XCTAssertEqual(originalParams.pplOutputType, 5) + XCTAssertTrue(originalParams.hellaswag) + XCTAssertEqual(originalParams.hellaswagTasks, 10) + XCTAssertFalse(originalParams.winogrande) + XCTAssertEqual(originalParams.winograndeTasks, 20) + XCTAssertTrue(originalParams.multipleChoice) + XCTAssertEqual(originalParams.multipleChoiceTasks, 30) + XCTAssertFalse(originalParams.klDivergence) + XCTAssertTrue(originalParams.usage) + XCTAssertFalse(originalParams.useColor) + XCTAssertTrue(originalParams.special) + XCTAssertFalse(originalParams.interactive) + XCTAssertTrue(originalParams.interactiveFirst) + XCTAssertFalse(originalParams.conversation) + XCTAssertTrue(originalParams.promptCacheAll) + XCTAssertFalse(originalParams.promptCacheRO) + XCTAssertTrue(originalParams.escapeSequences) + XCTAssertFalse(originalParams.multilineInput) + XCTAssertTrue(originalParams.simpleIO) + XCTAssertFalse(originalParams.continuousBatching) + XCTAssertTrue(originalParams.flashAttention) + XCTAssertFalse(originalParams.noPerformanceMetrics) + XCTAssertTrue(originalParams.contextShift) + + // Verify server and I/O settings + XCTAssertEqual(originalParams.port, 8080) + XCTAssertEqual(originalParams.timeoutRead, 60) + XCTAssertEqual(originalParams.timeoutWrite, 30) + XCTAssertEqual(originalParams.httpThreads, 4) + XCTAssertEqual(originalParams.hostname, "localhost") + XCTAssertEqual(originalParams.publicPath, "/public") + XCTAssertEqual(originalParams.chatTemplate, "template") + XCTAssertEqual(originalParams.systemPrompt, "system prompt") + XCTAssertTrue(originalParams.enableChatTemplate) + XCTAssertEqual(originalParams.apiKeys, ["key1", "key2"]) + XCTAssertEqual(originalParams.sslFileKey, "key.pem") + XCTAssertEqual(originalParams.sslFileCert, "cert.pem") + XCTAssertTrue(originalParams.endpointSlots) + XCTAssertFalse(originalParams.endpointMetrics) + XCTAssertTrue(originalParams.logJSON) + XCTAssertEqual(originalParams.slotSavePath, "/slots") + XCTAssertEqual(originalParams.slotPromptSimilarity, 0.75) + + // Verify batched-bench params + XCTAssertTrue(originalParams.isPPShared) + XCTAssertEqual(originalParams.nPP, [1, 2]) + XCTAssertEqual(originalParams.nTG, [3, 4]) + XCTAssertEqual(originalParams.nPL, [5, 6]) + + // Verify retrieval params + XCTAssertEqual(originalParams.contextFiles, ["context1.txt", "context2.txt"]) + XCTAssertEqual(originalParams.chunkSize, 1024) + XCTAssertEqual(originalParams.chunkSeparator, "\n") + + // Verify passkey params + XCTAssertEqual(originalParams.nJunk, 7) + XCTAssertEqual(originalParams.iPos, 8) + + // Verify imatrix params + XCTAssertEqual(originalParams.outFile, "output.txt") + XCTAssertEqual(originalParams.nOutFreq, 100) + XCTAssertEqual(originalParams.nSaveFreq, 200) + XCTAssertEqual(originalParams.iChunk, 9) + XCTAssertTrue(originalParams.processOutput) + XCTAssertFalse(originalParams.computePPL) + + // Verify cvector-generator params + XCTAssertEqual(originalParams.nPCABatch, 10) + XCTAssertEqual(originalParams.nPCAIterations, 11) + XCTAssertEqual(originalParams.cvectorDimreMethod, 12) + XCTAssertEqual(originalParams.cvectorOutfile, "cvector.out") + XCTAssertEqual(originalParams.cvectorPositiveFile, "positive.txt") + XCTAssertEqual(originalParams.cvectorNegativeFile, "negative.txt") + + // Verify additional properties + XCTAssertTrue(originalParams.spmInfill) + XCTAssertEqual(originalParams.loraOutfile, "lora.out") + XCTAssertFalse(originalParams.embedding) + XCTAssertTrue(originalParams.verbosePrompt) + XCTAssertFalse(originalParams.batchedBenchOutputJSONL) + XCTAssertTrue(originalParams.inputPrefixBOS) + XCTAssertFalse(originalParams.ctxShift) + XCTAssertTrue(originalParams.displayPrompt) + XCTAssertFalse(originalParams.logging) + + // Test the copy function + guard let copiedParams = originalParams.copy() as? GPTParams else { + XCTFail("Copy function did not return a GPTParams instance.") + return + } + + // Verify that the copied properties match the original + XCTAssertEqual(copiedParams.nPredict, originalParams.nPredict) + XCTAssertEqual(copiedParams.nCtx, originalParams.nCtx) + XCTAssertEqual(copiedParams.nBatch, originalParams.nBatch) + XCTAssertEqual(copiedParams.nUBatch, originalParams.nUBatch) + XCTAssertEqual(copiedParams.nKeep, originalParams.nKeep) + XCTAssertEqual(copiedParams.nDraft, originalParams.nDraft) + XCTAssertEqual(copiedParams.nChunks, originalParams.nChunks) + XCTAssertEqual(copiedParams.nParallel, originalParams.nParallel) + XCTAssertEqual(copiedParams.nSequences, originalParams.nSequences) + XCTAssertEqual(copiedParams.pSplit, originalParams.pSplit) + XCTAssertEqual(copiedParams.nGpuLayers, originalParams.nGpuLayers) + XCTAssertEqual(copiedParams.nGpuLayersDraft, originalParams.nGpuLayersDraft) + XCTAssertEqual(copiedParams.mainGpu, originalParams.mainGpu) + XCTAssertEqual(copiedParams.tensorSplit, originalParams.tensorSplit) + XCTAssertEqual(copiedParams.grpAttnN, originalParams.grpAttnN) + XCTAssertEqual(copiedParams.grpAttnW, originalParams.grpAttnW) + XCTAssertEqual(copiedParams.nPrint, originalParams.nPrint) + XCTAssertEqual(copiedParams.ropeFreqBase, originalParams.ropeFreqBase) + XCTAssertEqual(copiedParams.ropeFreqScale, originalParams.ropeFreqScale) + XCTAssertEqual(copiedParams.yarnExtFactor, originalParams.yarnExtFactor) + XCTAssertEqual(copiedParams.yarnAttnFactor, originalParams.yarnAttnFactor) + XCTAssertEqual(copiedParams.yarnBetaFast, originalParams.yarnBetaFast) + XCTAssertEqual(copiedParams.yarnBetaSlow, originalParams.yarnBetaSlow) + XCTAssertEqual(copiedParams.yarnOrigCtx, originalParams.yarnOrigCtx) + XCTAssertEqual(copiedParams.defragThold, originalParams.defragThold) + XCTAssertEqual(copiedParams.modelPath, originalParams.modelPath) + XCTAssertEqual(copiedParams.apiKeys, originalParams.apiKeys) +// XCTAssertEqual(copiedParams.controlVectors, originalParams.controlVectors) + // Continue verifying all other properties... + + // Verify that modifying the original does not affect the copy + originalParams.nPredict = 999 + XCTAssertNotEqual(copiedParams.nPredict, originalParams.nPredict) + } +} diff --git a/swift/test/LlamaKitTests.swift b/swift/test/LlamaKitTests.swift index f021e4fd6..900f2fac3 100644 --- a/swift/test/LlamaKitTests.swift +++ b/swift/test/LlamaKitTests.swift @@ -2,27 +2,28 @@ import Foundation import Testing @testable import LlamaKit import JSONSchema +import OSLog // MARK: LlamaGrammarSession Suite -@Suite("LlamaGrammarSession Suite") -struct LlamaGrammarSessionSuite { +@Suite("LlamaSession Suite") +struct LlamaSessionSuite { @JSONSchema struct Trip { let location: String let startDate: TimeInterval let durationInDays: Int } - func downloadFile() async throws -> String { + func downloadFile(url: String, to path: String) async throws -> String { let fm = FileManager.default let tmpDir = fm.temporaryDirectory - let destinationURL = tmpDir.appending(path: "tinyllama.gguf") + let destinationURL = tmpDir.appending(path: path) guard !fm.fileExists(atPath: destinationURL.path()) else { return destinationURL.path() } - print("Downloading TinyLlama, this may take a while...") + print("Downloading \(path), this may take a while...") // Define the URL - guard let url = URL(string: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf?download=true") else { + guard let url = URL(string: url) else { print("Invalid URL.") throw URLError(.badURL) } @@ -32,22 +33,22 @@ struct LlamaGrammarSessionSuite { // Define the destination path in the documents directory - // Move the downloaded file to the destination try fm.moveItem(at: tempURL, to: destinationURL) print("File downloaded to: \(destinationURL.path())") return destinationURL.path() } - @Test func llamaGrammarSession() async throws { + + func baseParams(url: String, to path: String) async throws -> GPTParams { let params = GPTParams() - params.modelPath = try await downloadFile() - params.nPredict = 256 - params.nCtx = 1024 - params.cpuParams.nThreads = 4 - params.cpuParamsBatch.nThreads = 4 + params.modelPath = try await downloadFile(url: url, to: path) + params.nPredict = 512 + params.nCtx = 4096 + params.cpuParams.nThreads = 8 + params.cpuParamsBatch.nThreads = 8 params.nBatch = 1024 - params.nGpuLayers = 128 + params.nGpuLayers = 1024 params.chatTemplate = """ <|system|> {system_message} @@ -55,6 +56,28 @@ struct LlamaGrammarSessionSuite { {prompt} <|assistant|> """ + params.interactive = true + return params + } + + @Test func llamaInferenceSession() async throws { + let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf") + params.prompt = """ + <|system|> + You are an AI assistant. Answer queries simply and concisely. + """ + params.antiPrompts = [""] + params.inputPrefix = "<|user|>" + params.inputSuffix = "<|assistant|>" + params.interactive = true + let session = try await LlamaChatSession(params: params, flush: false) + for await msg in await session.inferenceStream(message: "How are you today?") { + print(msg, terminator: "") + } + } + + @Test func llamaGrammarSession() async throws { + let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf") params.prompt = """ You are a travel agent. The current date epoch \(Date.now.timeIntervalSince1970). Responses should have the following fields: @@ -64,7 +87,6 @@ struct LlamaGrammarSessionSuite { durationInDays: the duration of the trip in days """ - params.interactive = true let session = try await LlamaSession(params: params) await #expect(throws: Never.self) { let trip = try await session.chat(message: "Please create a trip for me to New York City that starts two weeks from now. The duration of the trip MUST be 3 days long.") @@ -73,6 +95,25 @@ struct LlamaGrammarSessionSuite { // TODO: so for now, we are just asserting the grammar works } } + + @JSONSchema struct IsCorrect { + let isSpellingCorrect: Bool + } + + @Test func llamaSimpleGrammarSession() async throws { + let params = try await baseParams(url: "https://huggingface.co/RichardErkhov/openfoodfacts_-_spellcheck-mistral-7b-gguf/resolve/main/spellcheck-mistral-7b.Q8_0.gguf?download=true", + to: "spellcheck_q8.gguf") + params.prompt = """ + ###You are a spell checker. I will provide you with the word 'strawberry'. If the spelling of the given word is correct, please respond {"isCorrect": true} else respond {"isCorrect": false}.\n + """ + let session = try await LlamaSession(params: params) + for _ in 0..<10 { + var output = try await session.chat(message: "###strawberry\n") + #expect(output.isSpellingCorrect) + output = try await session.chat(message: "###strawberrry\n") + #expect(!output.isSpellingCorrect) + } + } } import WeatherKit @@ -115,7 +156,7 @@ func downloadFile() async throws -> String { /// Get the current weather in a given location. /// - parameter location: The city and state, e.g. San Francisco, CA /// - parameter unit: The unit of temperature - @Tool public static func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather { + @Tool public func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather { let weather = try await WeatherService().weather(for: CLGeocoder().geocodeAddressString(location)[0].location!) var temperature = weather.currentWeather.temperature temperature.convert(to: .fahrenheit) @@ -134,7 +175,7 @@ func downloadFile() async throws -> String { params.nBatch = 1024 params.nGpuLayers = 1024 let llama = try await MyLlama(params: params) - let currentWeather = try await MyLlama.getCurrentWeather(location: "San Francisco, CA", unit: "farenheit") + let currentWeather = try await llama.getCurrentWeather(location: "San Francisco, CA", unit: "farenheit") let output = try await llama.chat("What's the weather (in farenheit) in San Francisco, CA?") #expect(output.contains(String(format: "%d", Int(currentWeather.temperature)))) // #expect(output.contains(currentWeather.condition.rawValue))