Clean up and add more correctness

Commit b99e7f977f (parent 802687a4d8)
25 changed files with 1095 additions and 390 deletions

Package.swift
@@ -32,7 +32,7 @@ var cSettings: [CSetting] = [
     // We should consider add this in the future when we drop support for iOS 14
     // (ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
     .define("ACCELERATE_NEW_LAPACK"),
-    .define("ACCELERATE_LAPACK_ILP64")
+    .define("ACCELERATE_LAPACK_ILP64"),
 ]

 #if canImport(Darwin)
@@ -80,7 +80,8 @@ let package = Package(
                 "ggml"
             ],
             sources: cppSources,
-            publicHeadersPath: "spm-headers"),
+            publicHeadersPath: "spm-headers",
+            cSettings: cSettings),
         .target(
             name: "llama",
             dependencies: ["llama_cpp"],
@@ -94,6 +95,8 @@ let package = Package(
             dependencies: ["llama"],
             path: "objc",
             sources: [
+                "CPUParams.mm",
+                "GGMLThreadpool.mm",
                 "GPTParams.mm",
                 "GPTSampler.mm",
                 "LlamaBatch.mm",
objc/CPUParams.mm (new file, 68 lines)
@@ -0,0 +1,68 @@
#import <Foundation/Foundation.h>
#import "CPUParams_Private.hpp"
#import "GGMLThreadpool_Private.hpp"

@implementation CPUParams

- (instancetype)initWithParams:(cpu_params&)params
{
    self = [super init];
    if (self) {
        self->params = &params;
    }
    return self;
}

- (NSInteger)nThreads {
    return params->n_threads;
}

- (void)setNThreads:(NSInteger)nThreads {
    params->n_threads = nThreads;
}

- (BOOL)maskValid {
    return params->mask_valid;
}

- (void)setMaskValid:(BOOL)maskValid {
    params->mask_valid = maskValid;
}

- (GGMLSchedPriority)priority {
    return GGMLSchedPriority(params->priority);
}

- (void)setPriority:(GGMLSchedPriority)priority {
    params->priority = ggml_sched_priority(priority);
}

- (BOOL)strictCPU {
    return params->strict_cpu;
}

- (void)setStrictCPU:(BOOL)strictCPU {
    params->strict_cpu = strictCPU;
}

- (NSUInteger)poll {
    return params->poll;
}

- (void)setPoll:(NSUInteger)poll {
    params->poll = poll;
}

- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
    return params->cpumask[index];
}

- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
    params->cpumask[index] = value;
}

- (GGMLThreadpoolParams *)ggmlThreadpoolParams {
    return [[GGMLThreadpoolParams alloc] initWithParams:ggml_threadpool_params_from_cpu_params(*params)];
}

@end
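Note (not part of the commit): a minimal Swift usage sketch of the new wrapper, assuming the module is named LlamaObjC and GPTParams exposes cpuParams as the headers below declare.

    import LlamaObjC  // assumed module name

    let params = GPTParams()
    params.cpuParams.nThreads = 8            // writes through to cpu_params.n_threads
    params.cpuParams.priority = .high        // GGMLSchedPriorityHigh -> ggml_sched_priority
    params.cpuParams.setCpuMask(true, atIndex: 0)  // pin to core 0
    params.cpuParams.maskValid = true        // mask only applies when mask_valid is set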
objc/GGMLThreadpool.mm (new file, 52 lines)
@@ -0,0 +1,52 @@
#import <Foundation/Foundation.h>
#import "GGMLThreadpool_Private.hpp"

@implementation GGMLThreadpool

- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool
{
    self = [super init];
    if (self) {
        self->threadpool = threadpool;
    }
    return self;
}

- (ggml_threadpool *)threadpool {
    return self->threadpool;
}

@end

@implementation GGMLThreadpoolParams {
    ggml_threadpool_params params;
}

- (BOOL)getCpuMaskAtIndex:(NSUInteger)index {
    abort();
}

- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index {
    abort();
}

- (instancetype)initWithParams:(ggml_threadpool_params&&)params
{
    self = [super init];
    if (self) {
        self->params = params;
    }
    return self;
}

- (BOOL)isEqual:(id)other {
    GGMLThreadpoolParams *rhs = (GGMLThreadpoolParams *)other;
    ggml_threadpool_params rhs_params = rhs->params;
    return ggml_threadpool_params_match(&params, &rhs_params);
}

- (GGMLThreadpool *)threadpool {
    auto tp = ggml_threadpool_new(&params);
    return [[GGMLThreadpool alloc] initWithThreadpool:tp];
}
@end
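Note (not part of the commit): how these pieces compose, as a hedged Swift sketch. GGMLThreadpoolParams is the comparable value type (isEqual delegates to ggml_threadpool_params_match) and threadpool() materializes the pool; names follow the ObjC declarations in this commit, Swift bridging assumed, and context is an existing LlamaContext.

    let tpp = params.cpuParams.ggmlThreadpoolParams()
    let tppBatch = params.cpuParamsBatch.ggmlThreadpoolParams()

    // Only build a second pool when the batch configuration differs,
    // mirroring the params-match check the implementation uses.
    let pool = tpp.threadpool()
    let batchPool = (tpp == tppBatch) ? pool : tppBatch.threadpool()
    context.attachThreadpool(pool, threadpoolBatch: batchPool)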
objc/GPTParams.mm
@@ -1,127 +1,10 @@
 #import <Foundation/Foundation.h>
 #import "GPTParams_Private.hpp"
+#import "CPUParams_Private.hpp"
 #import "GPTSampler.h"
 #import "../common/common.h"
 #import "ggml.h"

-@implementation GGMLThreadpool { ... }
-@implementation GGMLThreadpoolParams { ... }
-@implementation CPUParams { ... }
 [removed here: the GGMLThreadpool, GGMLThreadpoolParams, and CPUParams implementations, moved verbatim into the new objc/GGMLThreadpool.mm and objc/CPUParams.mm above; the only change in the move is that CPUParams retypes nThreads and poll from int/uint32_t to NSInteger/NSUInteger]

 @implementation GPTSamplerParams {
     common_sampler_params *gpt_sampler_params;
 }
@@ -415,6 +298,42 @@
     return antiprompts;
 }

+- (void)setAntiPrompts:(NSArray<NSString *> *)antiPrompts {
+    gpt_params.antiprompt.clear();
+    for (NSString *antiprompt in antiPrompts) {
+        gpt_params.antiprompt.push_back([antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
+    }
+}
+
+- (NSArray<NSString *> *)apiKeys {
+    auto apiKeys = [[NSMutableArray alloc] init];
+    for (auto& apiKey : gpt_params.api_keys) {
+        [apiKeys addObject:[NSString stringWithCString:apiKey.c_str() encoding:NSUTF8StringEncoding]];
+    }
+    return apiKeys;
+}
+
+- (void)setApiKeys:(NSArray<NSString *> *)apiKeys {
+    gpt_params.api_keys.clear();
+    for (NSString *apiKey in apiKeys) {
+        gpt_params.api_keys.push_back([apiKey cStringUsingEncoding:NSUTF8StringEncoding]);
+    }
+}
+
+- (NSArray<NSNumber *> *)tensorSplit {
+    auto tensorSplit = [[NSMutableArray alloc] init];
+    for (auto& tensor : gpt_params.tensor_split) {
+        [tensorSplit addObject:[[NSNumber alloc] initWithFloat:tensor]];
+    }
+    return tensorSplit;
+}
+
+- (void)setTensorSplit:(NSArray<NSNumber *> *)tensorSplit {
+    for (size_t i = 0; i < [tensorSplit count]; i++) {
+        gpt_params.tensor_split[i] = [tensorSplit[i] floatValue];
+    }
+}
+
 - (common_params&)params {
     return gpt_params;
 }
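Note (not part of the commit): the new accessors bridge Cocoa collections into the underlying common_params vectors; a short Swift-side sketch, bridged names assumed.

    let params = GPTParams()
    params.antiPrompts = ["User:"]      // cleared, then copied into gpt_params.antiprompt
    params.apiKeys = ["example-key"]    // cleared, then copied into gpt_params.api_keys
    params.tensorSplit = [0.5, 0.5]     // written index-by-index into the fixed-size tensor_split array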
@@ -716,4 +635,29 @@
     gpt_params.input_suffix = [inputSuffix cStringUsingEncoding:NSUTF8StringEncoding];
 }

+- (BOOL)interactive {
+    return gpt_params.interactive;
+}
+
+- (void)setInteractive:(BOOL)interactive {
+    gpt_params.interactive = interactive;
+}
+
+- (BOOL)interactiveFirst {
+    return gpt_params.interactive_first;
+}
+
+- (void)setInteractiveFirst:(BOOL)interactiveFirst {
+    gpt_params.interactive_first = interactiveFirst;
+}
+
+- (id)copyWithZone:(NSZone *)zone {
+    GPTParams *copy = [[[self class] allocWithZone:zone] init];
+
+    if (copy) {
+        copy->gpt_params = gpt_params;
+    }
+
+    return copy;
+}
+
 @end
objc/LlamaContext.mm (filename inferred)
@@ -1,5 +1,6 @@
 #import <Foundation/Foundation.h>
 #import "LlamaContext_Private.hpp"
+#import "GGMLThreadpool_Private.hpp"
 #import "GPTParams_Private.hpp"
 #import "LlamaModel_Private.hpp"
 #import "LlamaBatch_Private.hpp"
@@ -17,6 +18,12 @@
     return self;
 }

+- (void)dealloc
+{
+    [super dealloc];
+    llama_free(ctx);
+}
+
 - (void)attachThreadpool:(GGMLThreadpool *)threadpool
          threadpoolBatch:(GGMLThreadpool *)threadpoolBatch {
     llama_attach_threadpool(ctx, [threadpool threadpool], [threadpoolBatch threadpool]);
objc/LlamaModel.mm (filename inferred)
@@ -22,6 +22,12 @@
     return self;
 }

+- (void)dealloc
+{
+    [super dealloc];
+    llama_free_model(model);
+}
+
 - (LlamaContext *)context:(LlamaContextParams *)params {
     return nil;
 }
objc/LlamaObjC.mm (filename inferred)
@@ -3,10 +3,12 @@
 #import "../../common/common.h"
 #import "LlamaModel_Private.hpp"
 #import "LlamaContext_Private.hpp"
+#import "CPUParams_Private.hpp"
 #import "GPTSampler.h"
 #import <OSLog/OSLog.h>
 #import "ggml.h"
 #import "GPTParams_Private.hpp"
+#import "GGMLThreadpool_Private.hpp"
 #import "LlamaBatch_Private.hpp"

 @implementation BlockingLineQueue {
@@ -75,8 +77,7 @@
 @implementation LlamaSession {
     std::vector<llama_token> embd_inp;
     std::vector<common_chat_msg> chat_msgs;
-    GPTParams *params;
-    GPTSampler *smpl;

     BOOL isInteracting;

     bool is_antiprompt;
@@ -105,13 +106,14 @@
     std::vector<std::vector<llama_token>> antiprompt_ids;
     BOOL need_insert_eot;
     int n_ctx;
+    os_log_t os_log_inst;
 }

 - (NSString *)chat_add_and_format:(std::vector<common_chat_msg> &) chat_msgs role:(const std::string &) role content:(const std::string &) content {
     common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single([self.model cModel], [params params].chat_template, chat_msgs, new_msg, role == "user");
+    auto formatted = common_chat_format_single([self.model cModel], [_params params].chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
-    os_log_debug(OS_LOG_DEFAULT, "formatted: '%s'\n", formatted.c_str());
+    os_log_debug(os_log_inst, "formatted: '%s'\n", formatted.c_str());
     return [NSString stringWithCString:formatted.c_str() encoding:NSUTF8StringEncoding];
 }
@@ -131,22 +133,24 @@ static BOOL file_is_empty(NSString *path) {

 - (instancetype)initWithParams:(GPTParams *)params {
     self = [super init];

-    self->params = params;
-    // model = llama_init.model;
-    // ctx = llama_init.context;
-    //
-    // if model == nil {
-    //     LOG_ERR("%s: error: unable to load model\n", __func__);
-    //     return 1;
-    // }
-    //
-    os_log_info(OS_LOG_DEFAULT,
-                "%s: llama threadpool init, n_threads = %d\n",
-                __func__, params.cpuParams.nThreads);
+    self->_params = [params copy];
+    self->_mutableLastOutput = [[NSMutableString alloc] init];
+    if (params.logging) {
+        os_log_inst = OS_LOG_DEFAULT;
+    } else {
+        os_log_inst = OS_LOG_DISABLED;
+    }
+    if (!params.modelPath) {
+        [NSException raise:@"ModelFailure"
+                    format:@"params.modelPath must be defined"];
+    }
+
+    os_log_info(os_log_inst,
+                "%s: llama threadpool init, n_threads = %ld\n",
+                __func__, static_cast<long>(params.cpuParams.nThreads));

     if (params.embedding) {
-        os_log_error(OS_LOG_DEFAULT,
+        os_log_error(os_log_inst,
                      R"(************
        please use the 'embedding' tool for embedding calculations
        ************)");

@@ -154,22 +158,29 @@ static BOOL file_is_empty(NSString *path) {
     }

     if (params.nCtx != 0 && params.nCtx < 8) {
-        os_log_info(OS_LOG_DEFAULT, "minimum context size is 8, using minimum size.");
+        os_log_info(os_log_inst, "minimum context size is 8, using minimum size.");
         params.nCtx = 8;
     }

     if (params.ropeFreqBase != 0) {
-        os_log_info(OS_LOG_DEFAULT, "changing RoPE frequency base to \(params.ropeFreqBase)");
+        os_log_info(os_log_inst, "changing RoPE frequency base to \(params.ropeFreqBase)");
     }

     if (params.ropeFreqScale != 0.0) {
-        os_log_info(OS_LOG_DEFAULT, "scaling RoPE frequency by \(params.ropeFreqScale)");
+        os_log_info(os_log_inst, "scaling RoPE frequency by \(params.ropeFreqScale)");
     }

     llama_backend_init();
     llama_numa_init(ggml_numa_strategy(params.numaStrategy));
     auto llama_init = common_init_from_params([params params]);

+    if (llama_init.context == nil) {
+        [NSException raise:@"ContextFailure"
+                    format:@"could not load context"];
+    }
+    if (llama_init.model == nil) {
+        [NSException raise:@"ModelLoadFailure"
+                    format:@"could not load model"];
+    }
     auto tpp_batch = params.cpuParamsBatch.ggmlThreadpoolParams;
     auto tpp = params.cpuParams.ggmlThreadpoolParams;

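Note (not part of the commit): with these changes the initializer fails fast instead of limping along. A minimal Swift-side setup that satisfies the new checks (bridged names assumed, model path hypothetical):

    let params = GPTParams()
    params.modelPath = "/path/to/model.gguf"  // nil now raises ModelFailure
    params.nCtx = 4096                        // values below 8 are clamped to the minimum
    let session = LlamaSession(params: params)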
@@ -179,7 +190,7 @@ static BOOL file_is_empty(NSString *path) {
     if (tpp != tpp_batch) {
         threadpool_batch = [tpp_batch threadpool];
         if (!threadpool_batch) {
-            [NSException raise:@"batch threadpool create failed"
+            [NSException raise:@"ThreadpoolFailure"
                         format:@"batch threadpool create failed"];
         }

@@ -189,7 +200,7 @@ static BOOL file_is_empty(NSString *path) {

     GGMLThreadpool *threadpool = [tpp threadpool];
     if (!threadpool) {
-        [NSException raise:@"threadpool create failed"
+        [NSException raise:@"ThreadpoolFailure"
                     format:@"threadpool create failed"];
     }

@@ -200,16 +211,16 @@ static BOOL file_is_empty(NSString *path) {
     n_ctx = [self.ctx nCtx];
     //
     if (n_ctx > n_ctx_train) {
-        os_log_info(OS_LOG_DEFAULT, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
+        os_log_info(os_log_inst, "%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }

     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enableChatTemplate) {
-            os_log_info(OS_LOG_DEFAULT, "%s: chat template example:\n%s\n", __func__,
+            os_log_info(os_log_inst, "%s: chat template example:\n%s\n", __func__,
                         [[self.model formatExample:params.chatTemplate] cStringUsingEncoding:NSUTF8StringEncoding]);
         } else {
-            os_log_info(OS_LOG_DEFAULT, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
+            os_log_info(os_log_inst, "%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
     }
     // print system information
@@ -222,11 +233,11 @@ static BOOL file_is_empty(NSString *path) {
     NSFileManager *fileManager = [NSFileManager defaultManager];

     if ([pathSession length] != 0) {
-        os_log_info(OS_LOG_DEFAULT, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
+        os_log_info(os_log_inst, "%s: attempting to load saved session from '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
         if (![fileManager fileExistsAtPath:pathSession]) {
-            os_log_info(OS_LOG_DEFAULT, "%s: session file does not exist, will create.\n", __func__);
+            os_log_info(os_log_inst, "%s: session file does not exist, will create.\n", __func__);
         } else if (file_is_empty(pathSession)) {
-            os_log_info(OS_LOG_DEFAULT,"%s: The session file is empty. A new session will be initialized.\n", __func__);
+            os_log_info(os_log_inst,"%s: The session file is empty. A new session will be initialized.\n", __func__);
         } else {
             // The file exists and is not empty
             session_tokens.resize(n_ctx);
@@ -235,7 +246,7 @@ static BOOL file_is_empty(NSString *path) {
                 [NSException raise:@"SessionLoadFailure" format:@"%s: failed to load session file '%s'\n", __func__, [pathSession cStringUsingEncoding:NSUTF8StringEncoding]];
             }
             session_tokens.resize(n_token_count_out);
-            os_log_info(OS_LOG_DEFAULT,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
+            os_log_info(os_log_inst,"%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }

@@ -244,7 +255,7 @@ static BOOL file_is_empty(NSString *path) {
         GGML_ASSERT(![self.model addEOSToken]);
     }

-    os_log_debug(OS_LOG_DEFAULT, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS);
+    os_log_debug(os_log_inst, "n_ctx: %d, add_bos: %d\n", n_ctx, addBOS);

     {
@@ -252,22 +263,22 @@ static BOOL file_is_empty(NSString *path) {
             ? [self chat_add_and_format:chat_msgs role:"system" content:[params params].prompt] // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactiveFirst || [params.prompt length] > 0 || session_tokens.empty()) {
-            os_log_debug(OS_LOG_DEFAULT, "tokenize the prompt\n");
+            os_log_debug(os_log_inst, "tokenize the prompt\n");
             embd_inp = [self.ctx tokenize:prompt addSpecial:true parseSpecial:true];
         } else {
-            os_log_debug(OS_LOG_DEFAULT,"use session tokens\n");
+            os_log_debug(os_log_inst,"use session tokens\n");
             embd_inp = session_tokens;
         }

-        os_log_debug(OS_LOG_DEFAULT,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]);
-        os_log_debug(OS_LOG_DEFAULT,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str());
+        os_log_debug(os_log_inst,"prompt: \"%s\"\n", [prompt cStringUsingEncoding:NSUTF8StringEncoding]);
+        os_log_debug(os_log_inst,"tokens: %s\n", [self.ctx convertTokensToString:embd_inp].c_str());
     }

     // Should not run without any tokens
     if (embd_inp.empty()) {
         if (addBOS) {
             embd_inp.push_back([self.model tokenBOS]);
-            // LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
+            os_log_info(os_log_inst, "embd_inp was considered empty and bos was added: %s\n", [_ctx convertTokensToString:embd_inp].c_str());
         } else {
             [NSException raise:@"InputEmptyError" format:@"input is empty"];
         }
@@ -303,13 +314,13 @@ static BOOL file_is_empty(NSString *path) {
             llama_kv_cache_seq_rm([self.ctx cContext], -1, n_matching_session_tokens, -1);
         }
         //
-        // os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
+        // os_log_debug(os_log_inst, "recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
         //              embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
         //
         // if we will use the cache for the full prompt without reaching the end of the cache, force
         // reevaluation of the last token to recalculate the cached logits
         if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
-            // os_log_debug(OS_LOG_DEFAULT, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
+            // os_log_debug(os_log_inst, "recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);

             session_tokens.resize(embd_inp.size() - 1);
         }
@@ -331,22 +342,21 @@ static BOOL file_is_empty(NSString *path) {
     }

     if (params.verbosePrompt) {
-        // LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        os_log_info(os_log_inst,
+                    "%s: prompt: '%s'\n", __func__, [params.prompt cStringUsingEncoding:NSUTF8StringEncoding]);
         // LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", embd_inp[i],
+            os_log_info(os_log_inst, "%6d -> '%s'\n", embd_inp[i],
                         [[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
         }

         if (params.nKeep > addBOS) {
             // LOG_INF("%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.nKeep; i++) {
-                os_log_debug(OS_LOG_DEFAULT, "%s",
+                os_log_debug(os_log_inst, "%s",
                              [[self.ctx tokenToPiece:embd_inp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
             }
             // LOG("'\n");
         }
         // LOG_INF("\n");
     }
     //
     // // ctrl+C handling
@@ -366,55 +376,55 @@ static BOOL file_is_empty(NSString *path) {
     // }
     //
     if (params.interactive) {
-        os_log_info(OS_LOG_DEFAULT, "%s: interactive mode on.\n", __func__);
+        os_log_info(os_log_inst, "%s: interactive mode on.\n", __func__);

         if ([params.antiPrompts count] > 0) {
             for (NSString *antiprompt in params.antiPrompts) {
-                os_log_info(OS_LOG_DEFAULT, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
+                os_log_info(os_log_inst, "Reverse prompt: '%s'\n", [antiprompt cStringUsingEncoding:NSUTF8StringEncoding]);
                 if (params.verbosePrompt) {
                     auto tmp = [_ctx tokenize:antiprompt
                                    addSpecial:false
                                  parseSpecial:true];
                     for (int i = 0; i < (int) tmp.size(); i++) {
-                        os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
+                        os_log_info(os_log_inst, "%6d -> '%s'\n", tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
                     }
                 }
             }
         }

         if (params.inputPrefixBOS) {
-            os_log_info(OS_LOG_DEFAULT, "Input prefix with BOS\n");
+            os_log_info(os_log_inst, "Input prefix with BOS\n");
         }

         if ([params.inputPrefix length] > 0) {
-            os_log_info(OS_LOG_DEFAULT, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
+            os_log_info(os_log_inst, "Input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
             if (params.verbosePrompt) {
                 auto tmp = [_ctx tokenize:params.inputPrefix addSpecial:true parseSpecial:true];
                 for (int i = 0; i < (int) tmp.size(); i++) {
-                    os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n",
+                    os_log_info(os_log_inst, "%6d -> '%s'\n",
                                 tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
                 }
             }
         }

         if ([params.inputSuffix length] > 0) {
-            os_log_info(OS_LOG_DEFAULT, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
+            os_log_info(os_log_inst, "Input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
             if (params.verbosePrompt) {
                 auto tmp = [_ctx tokenize:params.inputSuffix addSpecial:false parseSpecial:true];
                 for (int i = 0; i < (int) tmp.size(); i++) {
-                    os_log_info(OS_LOG_DEFAULT, "%6d -> '%s'\n",
+                    os_log_info(os_log_inst, "%6d -> '%s'\n",
                                 tmp[i], [[self.ctx tokenToPiece:tmp[i]] cStringUsingEncoding:NSUTF8StringEncoding]);
                 }
             }
         }
     }

-    smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]];
-    if (!smpl) {
+    _smpl = [[GPTSampler alloc] init:_model gptSamplerParams:[params samplerParams]];
+    if (!_smpl) {
         [NSException raise:@"SamplingFailure" format:@"failed to initialize sampling subsystem"];
     }

-    os_log_info(OS_LOG_DEFAULT, "sampler seed: %u\n", [smpl seed]);
+    os_log_info(os_log_inst, "sampler seed: %u\n", [_smpl seed]);
     // LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
     // LOG_INF("sampler chain: %s\n", gpt_sampler_print(smpl).c_str());
     //
@@ -431,7 +441,7 @@ static BOOL file_is_empty(NSString *path) {
         GGML_ASSERT(ga_w % ga_n == 0                    && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
         //GGML_ASSERT(n_ctx_train % ga_w == 0           && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
         //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n       && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
-        os_log_info(OS_LOG_DEFAULT, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast<long>(ga_n), static_cast<long>(ga_w));
+        os_log_info(os_log_inst, "self-extend: n_ctx_train = %d, grp_attn_n = %ld, grp_attn_w = %ld\n", n_ctx_train, static_cast<long>(ga_n), static_cast<long>(ga_w));
     }

     if (params.interactive) {
@@ -454,14 +464,6 @@ static BOOL file_is_empty(NSString *path) {
     need_to_save_session = [pathSession length] > 0 && n_matching_session_tokens < embd_inp.size();
     n_remain = params.nPredict;

-    // // the first thing we will do is to output the prompt, so set color accordingly
-    // console::set_display(console::prompt);
-    // display = params.display_prompt;
-    //
-
-
-
-
     antiprompt_ids.reserve([params.antiPrompts count]);
     for (NSString *antiprompt in params.antiPrompts) {
         antiprompt_ids.emplace_back([self.ctx tokenize:antiprompt addSpecial:false parseSpecial:true]);
@@ -486,8 +488,13 @@ static BOOL file_is_empty(NSString *path) {
     return self;
 }

+// MARK: LastOutput
+- (NSString *)lastOutput {
+    return [_mutableLastOutput copy];
+}
+
 - (void)start:(BlockingLineQueue *)queue {
-    while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
+    while ((n_remain != 0 && !is_antiprompt) || _params.interactive) {
         // predict
         if (!embd.empty()) {
            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
@@ -500,42 +507,42 @@ static BOOL file_is_empty(NSString *path) {
                 embd.resize(max_embd_size);

                 // console::set_display(console::error);
-                os_log_error(OS_LOG_DEFAULT, "<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+                os_log_error(os_log_inst, "<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
                 // console::set_display(console::reset);
             }

-            if (params.grpAttnN == 1) {
+            if (_params.grpAttnN == 1) {
                 // infinite text generation via context shifting
                 // if we run out of context:
                 // - take the n_keep first tokens from the original prompt (via n_past)
                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches

                 if (n_past + (int) embd.size() >= [_ctx nCtx]) {
-                    if (!params.ctxShift) {
-                        os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and context shift is disabled => stopping\n", __func__);
+                    if (!_params.ctxShift) {
+                        os_log_debug(os_log_inst, "\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                         break;
                     } else {
-                        if (params.nPredict == -2) {
-                            os_log_debug(OS_LOG_DEFAULT, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.nPredict);
+                        if (_params.nPredict == -2) {
+                            os_log_debug(os_log_inst, "\n\n%s: context full and n_predict == -%d => stopping\n", __func__, _params.nPredict);
                             break;
                         }

-                        const int n_left    = n_past - params.nKeep;
+                        const int n_left    = n_past - _params.nKeep;
                         const int n_discard = n_left/2;

-                        os_log_debug(OS_LOG_DEFAULT, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n",
-                                     n_past, n_left, static_cast<unsigned long>([_ctx nCtx]), params.nKeep, n_discard);
+                        os_log_debug(os_log_inst, "context full, swapping: n_past = %d, n_left = %d, n_ctx = %lu, n_keep = %d, n_discard = %d\n",
+                                     n_past, n_left, static_cast<unsigned long>([_ctx nCtx]), _params.nKeep, n_discard);

-                        llama_kv_cache_seq_rm ([self.ctx cContext], 0, params.nKeep            , params.nKeep + n_discard);
-                        llama_kv_cache_seq_add([self.ctx cContext], 0, params.nKeep + n_discard, n_past, -n_discard);
+                        llama_kv_cache_seq_rm ([self.ctx cContext], 0, _params.nKeep            , _params.nKeep + n_discard);
+                        llama_kv_cache_seq_add([self.ctx cContext], 0, _params.nKeep + n_discard, n_past, -n_discard);

                         n_past -= n_discard;

-                        os_log_debug(OS_LOG_DEFAULT, "after swap: n_past = %d\n", n_past);
+                        os_log_debug(os_log_inst, "after swap: n_past = %d\n", n_past);

-                        os_log_debug(OS_LOG_DEFAULT, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str());
+                        os_log_debug(os_log_inst, "embd: %s\n", [self.ctx convertTokensToString:embd].c_str());

-                        os_log_debug(OS_LOG_DEFAULT, "clear session path\n");
+                        os_log_debug(os_log_inst, "clear session path\n");
                         [pathSession setString:@""];
                     }
                 }
@@ -546,10 +553,10 @@ static BOOL file_is_empty(NSString *path) {
                     const int bd = (ga_w/ga_n)*(ga_n - 1);
                     const int dd = (ga_w/ga_n) - ib*bd - ga_w;

-                    os_log_debug(OS_LOG_DEFAULT, "\n");
-                    os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i), n_past, ib*bd, static_cast<long>(ga_i + ib*bd), n_past + ib*bd);
-                    os_log_debug(OS_LOG_DEFAULT, "div:   [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast<long>(ga_i + ib*bd), static_cast<long>(ga_i + ib*bd + ga_w), static_cast<long>(ga_n), static_cast<long>((ga_i + ib*bd)/ga_n), static_cast<long>((ga_i + ib*bd + ga_w)/ga_n));
-                    os_log_debug(OS_LOG_DEFAULT, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast<long>(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd);
+                    os_log_debug(os_log_inst, "\n");
+                    os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i), n_past, ib*bd, static_cast<long>(ga_i + ib*bd), n_past + ib*bd);
+                    os_log_debug(os_log_inst, "div:   [%6ld, %6ld] / %6ld -> [%6ld, %6ld]\n", static_cast<long>(ga_i + ib*bd), static_cast<long>(ga_i + ib*bd + ga_w), static_cast<long>(ga_n), static_cast<long>((ga_i + ib*bd)/ga_n), static_cast<long>((ga_i + ib*bd + ga_w)/ga_n));
+                    os_log_debug(os_log_inst, "shift: [%6ld, %6d] + %6d -> [%6ld, %6d]\n", static_cast<long>(ga_i + ib*bd + ga_w), n_past + ib*bd, dd, static_cast<long>(ga_i + ib*bd + ga_w + dd), n_past + ib*bd + dd);

                     [self.ctx kvCacheSeqAdd:0 p0:ga_i p1:n_past delta:ib*bd];
                     [self.ctx kvCacheSeqDiv:0 p0:ga_i + ib*bd p1:ga_i + ib*bd + ga_w delta:ga_n];
@@ -559,7 +566,7 @@ static BOOL file_is_empty(NSString *path) {

                     ga_i += ga_w/ga_n;

-                    os_log_debug(OS_LOG_DEFAULT, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast<long>(ga_i));
+                    os_log_debug(os_log_inst, "\nn_past_old = %d, n_past = %d, ga_i = %ld\n\n", n_past + bd, n_past, static_cast<long>(ga_i));
                 }
             }
@@ -585,13 +592,13 @@ static BOOL file_is_empty(NSString *path) {
                 }
             }

-            for (int i = 0; i < (int) embd.size(); i += params.nBatch) {
+            for (int i = 0; i < (int) embd.size(); i += _params.nBatch) {
                 int n_eval = (int) embd.size() - i;
-                if (n_eval > params.nBatch) {
-                    n_eval = params.nBatch;
+                if (n_eval > _params.nBatch) {
+                    n_eval = _params.nBatch;
                 }

-                os_log_debug(OS_LOG_DEFAULT, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str());
+                os_log_debug(os_log_inst, "eval: %s\n", [self.ctx convertTokensToString:embd].c_str());

                 if ([self.ctx decode:[[LlamaBatch alloc] initWithBatch:llama_batch_get_one(&embd[i], n_eval)]]) {
@@ -600,10 +607,10 @@ static BOOL file_is_empty(NSString *path) {

                 n_past += n_eval;

-                os_log_debug(OS_LOG_DEFAULT, "n_past = %d\n", n_past);
+                os_log_debug(os_log_inst, "n_past = %d\n", n_past);
                 // Display total tokens alongside total time
-                if (params.nPrint > 0 && n_past % params.nPrint == 0) {
-                    os_log_debug(OS_LOG_DEFAULT, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast<unsigned long>([self.ctx nCtx]));
+                if (_params.nPrint > 0 && n_past % _params.nPrint == 0) {
+                    os_log_debug(os_log_inst, "\n\033[31mTokens consumed so far = %d / %lu \033[0m\n", n_past, static_cast<unsigned long>([self.ctx nCtx]));
                 }
             }
@@ -617,19 +624,19 @@ static BOOL file_is_empty(NSString *path) {

         if ((int) embd_inp.size() <= n_consumed && !isInteracting) {
             // optionally save the session on first sample (for faster prompt loading next time)
-            if ([pathSession length] > 0 && need_to_save_session && !params.promptCacheRO) {
+            if ([pathSession length] > 0 && need_to_save_session && !_params.promptCacheRO) {
                 need_to_save_session = false;
                 [self.ctx saveStateFile:pathSession tokens:session_tokens.data() nTokenCount:session_tokens.size()];
                 // llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

-                os_log_debug(OS_LOG_DEFAULT, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
+                os_log_debug(os_log_inst, "saved session to %s\n", [pathSession cStringUsingEncoding:NSUTF8StringEncoding]);
             }

-            const llama_token idToken = [smpl sample:self.ctx index:-1];
+            const llama_token idToken = [_smpl sample:self.ctx index:-1];

-            [smpl accept:idToken acceptGrammar:true];
+            [_smpl accept:idToken acceptGrammar:true];

-            // os_log_debug(OS_LOG_DEFAULT, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
+            // os_log_debug(os_log_inst, "last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());

             embd.push_back(idToken);
@@ -639,19 +646,19 @@ static BOOL file_is_empty(NSString *path) {
             // decrement remaining sampling budget
             --n_remain;

-            os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain);
+            os_log_debug(os_log_inst, "n_remain: %d\n", n_remain);
         } else {
             // some user input remains from prompt or interaction, forward it to processing
-            os_log_debug(OS_LOG_DEFAULT, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
+            os_log_debug(os_log_inst, "embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);

                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                [smpl accept:embd_inp[n_consumed] acceptGrammar:false];
+                [_smpl accept:embd_inp[n_consumed] acceptGrammar:false];

                 ++n_consumed;
-                if ((int) embd.size() >= params.nBatch) {
+                if ((int) embd.size() >= _params.nBatch) {
                     break;
                 }
             }
@@ -662,10 +669,10 @@ static BOOL file_is_empty(NSString *path) {
            // std::cout<< "DISPLAYING TEXT" << std::endl;

             for (auto idToken : embd) {
-                NSString *token_str = [self.ctx tokenToPiece:idToken special:params.special];
+                NSString *token_str = [self.ctx tokenToPiece:idToken special:_params.special];

                 // Console/Stream Output
-                os_log_info(OS_LOG_DEFAULT, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]);
+                os_log_info(os_log_inst, "%s", [token_str cStringUsingEncoding:NSUTF8StringEncoding]);

                 // Record Displayed Tokens To Log
                 // Note: Generated tokens are created one by one hence this check
@@ -678,6 +685,9 @@ static BOOL file_is_empty(NSString *path) {
                     output_tokens.push_back(idToken);
                     output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding];
                     last_output_ss << [token_str cStringUsingEncoding:NSUTF8StringEncoding];
+                    [self willChangeValueForKey:@"lastOutput"];
+                    [_mutableLastOutput appendString:token_str];
+                    [self didChangeValueForKey:@"lastOutput"];
                 }
             }
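Note (not part of the commit): the manual will/didChangeValueForKey: pair makes lastOutput observable through plain KVO even though it is backed by a private mutable string. A Swift sketch of watching generation, using string-based KVO since only a lastOutput getter is shown:

    final class OutputWatcher: NSObject {
        override func observeValue(forKeyPath keyPath: String?, of object: Any?,
                                    change: [NSKeyValueChangeKey: Any]?, context: UnsafeMutableRawPointer?) {
            guard keyPath == "lastOutput" else { return }
            if let text = change?[.newKey] as? String {
                print("generated so far: \(text)")
            }
        }
    }

    let watcher = OutputWatcher()
    session.addObserver(watcher, forKeyPath: "lastOutput", options: [.new], context: nil)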
@@ -698,23 +708,23 @@ static BOOL file_is_empty(NSString *path) {
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
             // check for reverse prompt in the last n_prev tokens
-            if ([params.antiPrompts count] > 0) {
+            if ([_params.antiPrompts count] > 0) {
                 const int n_prev = 32;
-                NSString *last_output = [smpl previousString:self.ctx n:n_prev];
+                NSString *last_output = [_smpl previousString:self.ctx n:n_prev];

                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
                 // If we're not running interactively, the reverse prompt might be tokenized with some following characters
                 // so we'll compensate for that by widening the search window a bit.
-                for (NSString *antiprompt in params.antiPrompts) {
-                    size_t extra_padding = params.interactive ? 0 : 2;
+                for (NSString *antiprompt in _params.antiPrompts) {
+                    size_t extra_padding = _params.interactive ? 0 : 2;
                     size_t search_start_pos = [last_output length] > static_cast<size_t>([antiprompt length] + extra_padding)
                         ? [last_output length] - static_cast<size_t>([antiprompt length] + extra_padding)
                         : 0;

                     // TODO: Check if correct
                     if ([last_output rangeOfString:antiprompt options:0 range:NSMakeRange(search_start_pos, last_output.length - search_start_pos)].location != NSNotFound) {
-                        if (params.interactive) {
+                        if (_params.interactive) {
                             isInteracting = true;
                         }
                         is_antiprompt = true;
@@ -723,10 +733,10 @@ static BOOL file_is_empty(NSString *path) {
             }

             // check for reverse prompt using special tokens
-            llama_token last_token = [smpl last];
+            llama_token last_token = [_smpl last];
             for (std::vector<llama_token> ids : antiprompt_ids) {
                 if (ids.size() == 1 && last_token == ids[0]) {
-                    if (params.interactive) {
+                    if (_params.interactive) {
                         isInteracting = true;
                     }
                     is_antiprompt = true;
@@ -735,25 +745,25 @@ static BOOL file_is_empty(NSString *path) {
             }

             if (is_antiprompt) {
-                os_log_debug(OS_LOG_DEFAULT, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]);
+                os_log_debug(os_log_inst, "found antiprompt: %s\n", [last_output cStringUsingEncoding:NSUTF8StringEncoding]);
             }
         }

         // deal with end of generation tokens in interactive mode
-        if ([self.model tokenIsEOG:[smpl last]]) {
-            os_log_debug(OS_LOG_DEFAULT, "found an EOG token\n");
+        if ([self.model tokenIsEOG:[_smpl last]]) {
+            os_log_debug(os_log_inst, "found an EOG token\n");

-            if (params.interactive) {
-                if ([[params antiPrompts] count] > 0) {
+            if (_params.interactive) {
+                if ([[_params antiPrompts] count] > 0) {
                     // tokenize and inject first reverse prompt

-                    const auto first_antiprompt = [self.ctx tokenize:params.antiPrompts[0] addSpecial:false parseSpecial:true];
+                    const auto first_antiprompt = [self.ctx tokenize:_params.antiPrompts[0] addSpecial:false parseSpecial:true];
                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                     is_antiprompt = true;
                 }

-                if (params.enableChatTemplate) {
+                if (_params.enableChatTemplate) {
                     [self chat_add_and_format:chat_msgs
                                          role:"assistant"
                                       content:assistant_ss.str()];
@@ -764,32 +774,32 @@ static BOOL file_is_empty(NSString *path) {
         }

         // if current token is not EOG, we add it to current assistant message
-        if (params.conversation) {
-            const auto idToken = [smpl last];
+        if (_params.conversation) {
+            const auto idToken = [_smpl last];
             assistant_ss << [[self.ctx tokenToPiece:idToken special:false] cStringUsingEncoding:NSUTF8StringEncoding];
         }

         if (n_past > 0 && isInteracting) {
-            os_log_debug(OS_LOG_DEFAULT, "waiting for user input\n");
+            os_log_debug(os_log_inst, "waiting for user input\n");

-            if (params.conversation) {
+            if (_params.conversation) {
                 // osLog_("\n> ");
             }

-            if (params.inputPrefixBOS) {
-                os_log_debug(OS_LOG_DEFAULT, "adding input prefix BOS token\n");
+            if (_params.inputPrefixBOS) {
+                os_log_debug(os_log_inst, "adding input prefix BOS token\n");
                 embd_inp.push_back([self.model tokenBOS]);
             }

             std::string buffer;
-            if ([params.inputPrefix length] > 0 && !params.conversation) {
-                os_log_debug(OS_LOG_DEFAULT, "appending input prefix: '%s'\n", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
-                os_log_info(OS_LOG_DEFAULT, "%s", [params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
+            if ([_params.inputPrefix length] > 0 && !_params.conversation) {
+                os_log_debug(os_log_inst, "appending input prefix: '%s'\n", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
+                os_log_info(os_log_inst, "%s", [_params.inputPrefix cStringUsingEncoding:NSUTF8StringEncoding]);
             }

             // color user input only
             // console::set_display(console::user_input);
-            display = params.displayPrompt;
+            display = _params.displayPrompt;

             std::string line;
             // bool another_line = true;
@@ -806,8 +816,12 @@ static BOOL file_is_empty(NSString *path) {
                 auto str = last_output_ss.str();
                 last_output_ss.str("");
                 [queue addOutputLine:[NSString stringWithCString:str.c_str() encoding:NSUTF8StringEncoding]];
+                [self willChangeValueForKey:@"lastOutput"];
+                _mutableLastOutput = [[NSMutableString alloc] init];
+                [self didChangeValueForKey:@"lastOutput"];
             }

             buffer = [[queue inputLine] cStringUsingEncoding:NSUTF8StringEncoding];
             // do {
             //     another_line = console::readline(line, params.multiline_input);
@@ -822,34 +836,34 @@ static BOOL file_is_empty(NSString *path) {
             // Entering a empty line lets the user pass control back
             if (buffer.length() > 1) {
                 // append input suffix if any
-                if ([params.inputSuffix length] > 0 && !params.conversation) {
-                    os_log_debug(OS_LOG_DEFAULT, "appending input suffix: '%s'\n", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
-                    os_log_info(OS_LOG_DEFAULT, "%s", [params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
+                if ([[self params].inputSuffix length] > 0 && !_params.conversation) {
+                    os_log_debug(os_log_inst, "appending input suffix: '%s'\n", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
+                    os_log_info(os_log_inst, "%s", [_params.inputSuffix cStringUsingEncoding:NSUTF8StringEncoding]);
                 }

-                os_log_debug(OS_LOG_DEFAULT, "buffer: '%s'\n", buffer.c_str());
+                os_log_debug(os_log_inst, "buffer: '%s'\n", buffer.c_str());

                 const size_t original_size = embd_inp.size();

-                if (params.escapeSequences) {
+                if (_params.escapeSequences) {
                     string_process_escapes(buffer);
                 }

-                bool format_chat = params.conversation && params.enableChatTemplate;
+                bool format_chat = _params.conversation && _params.enableChatTemplate;
                 std::string user_inp = format_chat
                     ? [[self chat_add_and_format:chat_msgs role:"user" content:std::move(buffer)] cStringUsingEncoding:NSUTF8StringEncoding]
                     : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
-                const auto line_pfx = [self.ctx tokenize:params.inputPrefix addSpecial:false parseSpecial:true];
+                const auto line_pfx = [self.ctx tokenize:_params.inputPrefix addSpecial:false parseSpecial:true];
                 const auto line_inp = [self.ctx tokenize:[NSString stringWithCString:user_inp.c_str()
                                                                              encoding:NSUTF8StringEncoding]
                                               addSpecial:false
                                             parseSpecial:format_chat];
-                const auto line_sfx = [self.ctx tokenize:params.inputSuffix
+                const auto line_sfx = [self.ctx tokenize:_params.inputSuffix
                                               addSpecial:false
                                             parseSpecial:true];

-                os_log_debug(OS_LOG_DEFAULT, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str());
+                os_log_debug(os_log_inst, "input tokens: %s\n", [self.ctx convertTokensToString:line_inp].c_str());

                 // if user stop generation mid-way, we must add EOT to finish model's last response
                 if (need_insert_eot && format_chat) {
@@ -872,9 +886,9 @@ static BOOL file_is_empty(NSString *path) {
                 assistant_ss.str("");

                 n_remain -= line_inp.size();
-                os_log_debug(OS_LOG_DEFAULT, "n_remain: %d\n", n_remain);
+                os_log_debug(os_log_inst, "n_remain: %d\n", n_remain);
             } else {
-                os_log_debug(OS_LOG_DEFAULT, "empty line, passing control back\n");
+                os_log_debug(os_log_inst, "empty line, passing control back\n");
             }

             input_echo = false; // do not echo this again
@@ -882,22 +896,22 @@ static BOOL file_is_empty(NSString *path) {

         if (n_past > 0) {
             if (isInteracting) {
-                [smpl reset];
+                [_smpl reset];
             }
             isInteracting = false;
         }
     }

     // end of generation
-    if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(params.interactive)) {
-        os_log_info(OS_LOG_DEFAULT, " [end of text]\n");
+    if (!embd.empty() && [self.model tokenIsEOG:embd.back()] && !(_params.interactive)) {
+        os_log_info(os_log_inst, " [end of text]\n");
         break;
     }

     // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
     // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
-    if (params.interactive && n_remain <= 0 && params.nPredict >= 0) {
-        n_remain = params.nPredict;
+    if (_params.interactive && n_remain <= 0 && _params.nPredict >= 0) {
+        n_remain = _params.nPredict;
         isInteracting = true;
     }
 }
objc/include/CPUParams.h (new file, 30 lines)
@@ -0,0 +1,30 @@
#ifndef CPUParams_h
#define CPUParams_h

typedef NS_ENUM(NSUInteger, GGMLSchedPriority);

@class GGMLThreadpoolParams;

@interface CPUParams : NSObject

/// Number of threads to use
@property (nonatomic, assign) NSInteger nThreads;

/// Default: any CPU
@property (nonatomic, assign) BOOL maskValid;
/// Scheduling priority
@property (nonatomic, assign) GGMLSchedPriority priority;
/// Use strict CPU placement
@property (nonatomic, assign) BOOL strictCPU;
/// Polling (busywait) level (0 - no polling, 100 - mostly polling)
@property (nonatomic, assign) NSUInteger poll;

// Custom methods to access or manipulate the cpumask array
- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;

- (GGMLThreadpoolParams *)ggmlThreadpoolParams;

@end

#endif /* Header_h */
objc/include/CPUParams_Private.hpp (new file, 15 lines)
@@ -0,0 +1,15 @@
#ifndef CPUParams_Private_hpp
#define CPUParams_Private_hpp

#import "CPUParams.h"
#import "../../common/common.h"

@interface CPUParams() {
    cpu_params *params;
}

- (instancetype)initWithParams:(cpu_params&)params;

@end

#endif /* CPUParams_Private_hpp */
objc/include/GGMLThreadpool.h (new file, 29 lines)
@@ -0,0 +1,29 @@
#ifndef GGMLThreadpool_h
#define GGMLThreadpool_h

typedef NS_ENUM(NSUInteger, GGMLSchedPriority) {
    GGMLSchedPriorityNormal   = 0,  // Normal priority
    GGMLSchedPriorityMedium   = 1,  // Medium priority
    GGMLSchedPriorityHigh     = 2,  // High priority
    GGMLSchedPriorityRealtime = 3   // Realtime priority
};

@interface GGMLThreadpool : NSObject
@end

@interface GGMLThreadpoolParams : NSObject

@property (nonatomic, assign) int nThreads;
@property (nonatomic, assign) GGMLSchedPriority priority;
@property (nonatomic, assign) uint32_t poll;
@property (nonatomic, assign) BOOL strictCPU;
@property (nonatomic, assign) BOOL paused;

- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
- (GGMLThreadpool *)threadpool;

@end

#endif /* GGMLThreadpool_h */
objc/include/GGMLThreadpool_Private.hpp (new file, 22 lines)
@@ -0,0 +1,22 @@
#ifndef GGMLThreadpool_Private_hpp
#define GGMLThreadpool_Private_hpp

#import "GGMLThreadpool.h"
#import "../../common/common.h"

@interface GGMLThreadpool() {
    ggml_threadpool *threadpool;
}

- (instancetype)initWithThreadpool:(ggml_threadpool *)threadpool;
- (ggml_threadpool *)threadpool;

@end

@interface GGMLThreadpoolParams()

- (instancetype)initWithParams:(ggml_threadpool_params&&)params;

@end

#endif /* GGMLThreadpool_Private_hpp */

objc/include/GPTParams.h
@@ -4,48 +4,7 @@
 @class LlamaModelParams;
 @class LlamaContextParams;
 @class GGMLThreadpool;
-
-// Define the ggml_sched_priority enum
-typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
-    GGMLSchedPriorityNormal = 0,   // Normal priority
-    GGMLSchedPriorityMedium = 1,   // Medium priority
-    GGMLSchedPriorityHigh = 2,     // High priority
-    GGMLSchedPriorityRealtime = 3  // Realtime priority
-};
-
-@interface GGMLThreadpoolParams : NSObject
-
-@property (nonatomic, assign) int nThreads;
-@property (nonatomic, assign) GGMLSchedPriority priority;
-@property (nonatomic, assign) uint32_t poll;
-@property (nonatomic, assign) BOOL strictCPU;
-@property (nonatomic, assign) BOOL paused;
-
-// Custom access methods for the cpumask array
-- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
-- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
-- (GGMLThreadpool *)threadpool;
-
-@end
-
-@interface GGMLThreadpool : NSObject
-@end
-
-@interface CPUParams : NSObject
-
-// Properties
-@property (nonatomic, assign) int nThreads;
-@property (nonatomic, assign) BOOL maskValid;
-@property (nonatomic, assign) GGMLSchedPriority priority;
-@property (nonatomic, assign) BOOL strictCPU;
-@property (nonatomic, assign) uint32_t poll;
-
-// Custom methods to access or manipulate the cpumask array
-- (BOOL)getCpuMaskAtIndex:(NSUInteger)index;
-- (void)setCpuMask:(BOOL)value atIndex:(NSUInteger)index;
-- (GGMLThreadpoolParams *)ggmlThreadpoolParams;
-
-@end
+@class CPUParams;

 @interface GPTSamplerParams : NSObject
@@ -72,13 +31,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
 @property (nonatomic, assign) BOOL penalizeNl;
 @property (nonatomic, assign) BOOL ignoreEos;
 @property (nonatomic, assign) BOOL noPerf;
+@property (nonatomic, strong) NSArray<NSNumber *> *samplers;
+@property (nonatomic, copy) NSString *grammar;
+@property (nonatomic, strong) NSArray<NSNumber *> *logitBias;

-// Arrays and Strings
-@property (nonatomic, strong) NSArray<NSNumber *> *samplers; // Samplers mapped to NSArray of NSNumber (for enums)
-@property (nonatomic, copy) NSString *grammar; // Grammar as NSString
-@property (nonatomic, strong) NSArray<NSNumber *> *logitBias; // Logit biases mapped to NSArray of NSNumber
-
 // Method to print the parameters into a string
 - (NSString *)print;

 @end
@@ -98,7 +54,7 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
 @property (nonatomic, assign) int32_t nGpuLayers;
 @property (nonatomic, assign) int32_t nGpuLayersDraft;
 @property (nonatomic, assign) int32_t mainGpu;
-@property (nonatomic, strong) NSMutableArray<NSNumber *> *tensorSplit; // Fixed-size array, stays the same
+@property (nonatomic, strong) NSArray<NSNumber *> *tensorSplit;
 @property (nonatomic, assign) int32_t grpAttnN;
 @property (nonatomic, assign) int32_t grpAttnW;
 @property (nonatomic, assign) int32_t nPrint;
@@ -111,13 +67,11 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
 @property (nonatomic, assign) int32_t yarnOrigCtx;
 @property (nonatomic, assign) float defragThold;

-// You need to replace your C++ struct "cpu_params" with an Objective-C class or struct accordingly
 @property (nonatomic, strong) CPUParams *cpuParams;
 @property (nonatomic, strong) CPUParams *cpuParamsBatch;
 @property (nonatomic, strong) CPUParams *draftCpuParams;
 @property (nonatomic, strong) CPUParams *draftCpuParamsBatch;

 // Callbacks (assuming they are blocks in Objective-C)
 @property (nonatomic, copy) void (^cbEval)(void *);
 @property (nonatomic, assign) void *cbEvalUserData;
@@ -149,12 +103,10 @@ typedef NS_ENUM(NSInteger, GGMLSchedPriority) {
 @property (nonatomic, copy) NSString *logitsFile;
 @property (nonatomic, copy) NSString *rpcServers;

-// Arrays in Objective-C are represented with `NSArray`
 @property (nonatomic, strong) NSArray<NSString *> *inputFiles;
 @property (nonatomic, strong) NSArray<NSString *> *antiPrompts;
 @property (nonatomic, strong) NSArray *kvOverrides;

-// Boolean values (in Objective-C, use `BOOL`)
 @property (nonatomic, assign) BOOL loraInitWithoutApply;
 @property (nonatomic, strong) NSArray *loraAdapters;
 @property (nonatomic, strong) NSArray *controlVectors;
@property (nonatomic, assign) BOOL ctxShift; // context shift on inifinite text generation
|
||||
@property (nonatomic, assign) BOOL displayPrompt; // print prompt before generation
|
||||
|
||||
@property (nonatomic, assign) BOOL logging; // print logging
|
||||
|
||||
@end
|
||||
|
||||
#endif /* GPTParams_h */
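Note: pulled together, a typical Swift setup looks like the following. This mirrors the `baseParams` helper in the tests below; paths and sizes are placeholders, not recommendations.

    let params = GPTParams()
    params.modelPath = "/path/to/model.gguf"
    params.nCtx = 4096              // context window
    params.nPredict = 512           // generation budget
    params.nBatch = 1024
    params.nGpuLayers = 128         // layers offloaded to the GPU
    params.cpuParams.nThreads = 8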
@@ -5,11 +5,6 @@

#import "ggml.h"
#import "../../common/common.h"

@interface GGMLThreadpool()

- (ggml_threadpool *)threadpool;

@end

@interface GPTParams()
@@ -38,14 +38,14 @@ typedef int32_t LlamaToken;

                index:(NSInteger)index
         grammarFirst:(BOOL)grammarFirst;

// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar

/// If accept_grammar is true, the token is accepted both by the sampling chain and the grammar

- (void)accept:(LlamaToken)token
 acceptGrammar:(BOOL)acceptGrammar;

// get a string representation of the last accepted tokens

/// Get a string representation of the last accepted tokens

- (NSString *)previousString:(LlamaContext *)context n:(NSInteger)n;

// get the last accepted token

/// Get the last accepted token

- (LlamaToken)last;

- (void)reset;
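Note: the sampling entry point's name is truncated above (only its `index:`/`grammarFirst:` tail is visible), so this sketch sticks to the methods whose signatures are fully shown; the Swift bridging names are assumed, and `sampler`/`ctx` stand in for a `GPTSampler` and `LlamaContext` obtained elsewhere.

    sampler.accept(token, acceptGrammar: true)     // feed a token to the chain and the grammar
    let newest = sampler.last()                    // most recently accepted token
    let tail = sampler.previousString(ctx, n: 32)  // readable tail of accepted tokens
    sampler.reset()                                // clear sampler state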
@@ -5,15 +5,15 @@ typedef NSInteger LlamaSequenceId;

typedef NSInteger LlamaPosition;
typedef int32_t LlamaToken;

// Input data for llama_decode
// A llama_batch object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
//
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output

/// Input data for llama_decode
/// A llama_batch object can contain input about one or many sequences
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
///
/// - token : the token ids of the input (used when embd is NULL)
/// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// - pos : the positions of the respective token in the sequence
/// - seq_id : the sequence to which the respective token belongs
/// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output

@interface LlamaBatch : NSObject

@property (nonatomic, assign) NSInteger nTokens;
@@ -2,6 +2,8 @@
#define LlamaObjC_h

#include <Foundation/Foundation.h>
#include <CPUParams.h>
#include <GGMLThreadpool.h>
#include <GPTParams.h>
#include <GPTSampler.h>
#include <llama.h>
@@ -14,10 +14,11 @@

@end

@interface LlamaSession : NSObject

NS_REFINED_FOR_SWIFT @interface LlamaSession : NSObject

@property (nonatomic, strong) LlamaModel *model;
@property (nonatomic, strong) LlamaContext *ctx;
@property (nonatomic, strong, readonly) NSString *lastOutput;

- (instancetype)initWithParams:(GPTParams *)params;
- (void)start:(BlockingLineQueue *)queue;
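Note: `NS_REFINED_FOR_SWIFT` makes Swift import the class with a double-underscore prefix, which is why the Swift sources later in this diff construct `__LlamaSession` rather than `LlamaObjC.LlamaSession`. The raw type is meant to be wrapped by a refined Swift API such as `LlamaChatSession`:

    let session = __LlamaSession(params: params)  // refined-for-Swift name of LlamaSession
    session.start(queue)                          // drive the blocking line queue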
@@ -3,8 +3,14 @@

#import "LlamaSession.h"

@class GPTSampler;

@interface LlamaSession()

@property (atomic, strong) NSMutableString *mutableLastOutput;
@property (nonatomic, strong) GPTParams *params;
@property (nonatomic, strong) GPTSampler *smpl;

@end

#endif /* LlamaSession_Private_hpp */
@@ -24,8 +24,6 @@ public class SchemaConverter {
    }

    private func formatLiteral(_ literal: Any) -> String {
//        let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)", {
//            let regex = Regex("[\r\n\"]")
        let escaped = GRAMMAR_LITERAL_ESCAPES.reduce("\(literal)") {
            $0.replacingOccurrences(of: $1.key, with: $1.value)
        }
@@ -57,6 +57,28 @@ public struct _JSONFunctionSchema: Codable {
            self.enum = nil
        }

        public init(type: Int.Type, description: String?) {
            self.type = "integer"
            self.description = description
            self.items = nil
            self.enum = nil
        }

        public init(type: Double.Type, description: String?) {
            self.type = "number"
            self.description = description
            self.items = nil
            self.enum = nil
        }

        public init(type: Bool.Type, description: String?) {
            self.type = "boolean"
            self.description = description
            self.items = nil
            self.enum = nil
        }

        public init<T: CaseIterable>(type: T.Type, description: String?) where T: RawRepresentable,
                                                                               T: StringProtocol {
            self.type = "string"
@@ -146,6 +168,14 @@ extension Double : JSONSchemaConvertible {
        ]
    }
}
extension Bool : JSONSchemaConvertible {
    public static var type: String { "boolean" }
    public static var jsonSchema: [String: Any] {
        [
            "type": "boolean"
        ]
    }
}
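Note: conforming a further type follows the same shape. A sketch with a hypothetical `Mood` enum, assuming the protocol's requirements are just the two members shown (plus `Decodable`, which the typed session below relies on):

    enum Mood: String, CaseIterable, Codable, JSONSchemaConvertible {
        case happy, sad
        public static var type: String { "string" }
        public static var jsonSchema: [String: Any] {
            ["type": "string", "enum": Mood.allCases.map(\.rawValue)]
        }
    }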
extension Date : JSONSchemaConvertible {
    public static var type: String { "string" }

@@ -7,11 +7,63 @@ public protocol DynamicCallable: Sendable {
    func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String
}

public enum AnyDecodable: Decodable {
    case string(String)
    case int(Int)
    case double(Double)
    case bool(Bool)
    case null
    // Add other cases as needed

    // Initializers for each type
    init(_ value: String) {
        self = .string(value)
    }

    init(_ value: Int) {
        self = .int(value)
    }

    init(_ value: Double) {
        self = .double(value)
    }

    init(_ value: Bool) {
        self = .bool(value)
    }

    init() {
        self = .null
    }

    // Decodable conformance
    public init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()

        if container.decodeNil() {
            self = .null
        } else if let intValue = try? container.decode(Int.self) {
            self = .int(intValue)
        } else if let doubleValue = try? container.decode(Double.self) {
            self = .double(doubleValue)
        } else if let boolValue = try? container.decode(Bool.self) {
            self = .bool(boolValue)
        } else if let stringValue = try? container.decode(String.self) {
            self = .string(stringValue)
        } else {
            let context = DecodingError.Context(
                codingPath: decoder.codingPath,
                debugDescription: "Cannot decode AnyDecodable"
            )
            throw DecodingError.typeMismatch(AnyDecodable.self, context)
        }
    }
}
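Note: the point of this enum is that tool arguments arrive as heterogeneous JSON, and `[String: AnyDecodable]` lets a single `ToolCall` decode all of them. A quick illustration with invented values:

    let json = #"{"city": "Berlin", "days": 3, "metric": true}"#.data(using: .utf8)!
    let args = try JSONDecoder().decode([String: AnyDecodable].self, from: json)
    // args["city"] is .string("Berlin"), args["days"] is .int(3), args["metric"] is .bool(true)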
struct ToolCall: Decodable {
    let id: Int
    let name: String
    let arguments: [String: String]
    let arguments: [String: AnyDecodable]
}

struct ToolResponse<T: Encodable>: Encodable {

@@ -23,10 +75,13 @@ struct ToolResponse<T: Encodable>: Encodable {

/// Standard chat session for a given LLM.
public actor LlamaChatSession {
    private let queue = BlockingLineQueue()
    private let session: LlamaObjC.LlamaSession
    private let session: __LlamaSession

    /// Initialize the session
    /// - parameter params: common parameters to initialize the session
    /// - parameter flush: whether or not to flush the initial prompt, reading initial output
    public init(params: GPTParams, flush: Bool = true) async throws {
        session = LlamaObjC.LlamaSession(params: params)
        self.session = __LlamaSession(params: params)
        Task.detached { [session, queue] in
            session.start(queue)
        }
@@ -36,7 +91,36 @@ public actor LlamaChatSession {
        _ = queue.outputLine()
    }

    public func chat(message: String) async -> String {
    /// Create a new inference stream for a given message
    /// - parameter message: The message to receive an inference for.
    /// - returns: A stream of output from the LLM.
    public func inferenceStream(message: String) async -> AsyncStream<String> {
        queue.addInputLine(message)
        var observationToken: NSKeyValueObservation?
        return AsyncStream { stream in
            observationToken = self.session.observe(\.lastOutput, options: [.new, .old]) { session, change in
                guard let newValue = change.newValue,
                      let oldValue = change.oldValue else {
                    return stream.finish()
                }
                var delta = ""
                for change in newValue!.difference(from: oldValue!) {
                    switch change {
                    case .remove(_, _, _):
                        return stream.finish()
                    case .insert(_, let element, _):
                        delta.append(element)
                    }
                }
                stream.yield(delta)
            }
            stream.onTermination = { [observationToken] _ in
                observationToken?.invalidate()
            }
        }
    }

    public func infer(message: String) async -> String {
        queue.addInputLine(message)
        return queue.outputLine()
    }
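Note: consuming the stream is a plain `for await` loop; each element is the delta since the previous KVO notification. This mirrors the `llamaInferenceSession` test below:

    let session = try await LlamaChatSession(params: params, flush: false)
    for await delta in await session.inferenceStream(message: "How are you today?") {
        print(delta, terminator: "")
    }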
@@ -46,15 +130,15 @@ public actor LlamaChatSession {
public actor LlamaSession<T: JSONSchemaConvertible> {
    private let session: LlamaChatSession

    public init(params: GPTParams) async throws {
    public init(params: GPTParams, flush: Bool = true) async throws {
        let converter = SchemaConverter(propOrder: [])
        _ = converter.visit(schema: T.jsonSchema, name: nil)
        params.samplerParams.grammar = converter.formatGrammar()
        session = try await LlamaChatSession(params: params)
        session = try await LlamaChatSession(params: params, flush: flush)
    }

    public func chat(message: String) async throws -> T {
        let output = await session.chat(message: message).data(using: .utf8)!
        let output = await session.infer(message: message).data(using: .utf8)!
        return try JSONDecoder().decode(T.self, from: output)
    }
}
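Note: the generic parameter drives grammar-constrained decoding end to end: the JSON schema becomes a GBNF grammar, so the raw model output is guaranteed to decode. In miniature, mirroring the `IsCorrect` test below (and assuming the `@JSONSchema` macro from the test target):

    @JSONSchema struct Verdict { let isSpellingCorrect: Bool }
    let session = try await LlamaSession<Verdict>(params: params)
    let verdict = try await session.chat(message: "Is 'strawberry' spelled correctly?")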
@@ -117,11 +201,14 @@ public actor LlamaToolSession {
        self.tools["getIpAddress"] = (GetIpAddress(), ipFnSchema)
        let encoded = try JSONEncoder().encode(self.tools.values.map(\.1))
        let prompt = """
        \(params.prompt ?? "")

        You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
        <tool_call>
        {"name": <function-name>,"arguments": <args-dict>}
        </tool_call>

        Feel free to chain tool calls, e.g., if you need the user's location to find points of interest near them, fetch the user's location first.
        The first call you will be asked to warm up is to get the user's IP address. Here are the available tools:
        <tools> \(String(data: encoded, encoding: .utf8)!) </tools><|eot_id|>
        """

@@ -131,7 +218,7 @@ public actor LlamaToolSession {
        params.inputPrefix = "<|start_header_id|>user<|end_header_id|>";
        params.inputSuffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>";
        session = try await LlamaChatSession(params: params, flush: false)
        let fn = await session.chat(message: "What is my IP address?")
        let fn = await session.infer(message: "What is my IP address?")
        let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!)
        guard let tool = self.tools[toolCall.name] else {
            fatalError()

@@ -139,7 +226,7 @@ public actor LlamaToolSession {
        let resp = try await tool.0.dynamicallyCall(withKeywordArguments: toolCall.arguments)
        print(resp)

        let output = await session.chat(message: """
        let output = await session.infer(message: """
        <tool_response>
        {"id": \(toolCall.id), result: \(resp)}
        </tool_response>

@@ -147,39 +234,58 @@ public actor LlamaToolSession {
        print(output)
    }
    public func chat(message: String) async throws -> String {
        var nxt = await session.chat(message: message)
        let fn = nxt
        // try to see if the output is a function call
    private func callTool(_ call: String) async -> String? {
        var nxt: String?
        do {
            let toolCall = try JSONDecoder().decode(ToolCall.self, from: fn.data(using: .utf8)!)
            let toolCall = try JSONDecoder().decode(ToolCall.self, from: call.data(using: .utf8)!)
            guard let tool = tools[toolCall.name] else {
                fatalError()
            }
            // TODO: tool call decode is allowed to fail but the code below is not
            let callable = tool.0
            let resp = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments)
            print("tool response: \(resp)")
            nxt = await session.chat(message: """
            <tool_response>
            {"id": \(toolCall.id), result: \(resp)}
            </tool_response>
            """)
            print(nxt)
        } catch {
            print(error)
        }

            do {
                let response = try await callable.dynamicallyCall(withKeywordArguments: toolCall.arguments)
                print("tool response: \(response)")
                nxt = await session.infer(message: """
                <tool_response>
                {"id": \(toolCall.id), result: \(response)}
                </tool_response>
                """)
                // TODO: If this decodes correctly, we should tail this into this method
                // TODO: so that we do not decode twice
                if let _ = try? JSONDecoder().decode(ToolCall.self, from: nxt!.data(using: .utf8)!) {
                    return await callTool(nxt!)
                }
            } catch {
                nxt = await session.infer(message: """
                <tool_response>
                {"id": \(toolCall.id), result: "The tool call has unfortunately failed."}
                </tool_response>
                """)
            }
            print(nxt ?? "nil")
        } catch {}
        return nxt
    }

    public func infer(message: String) async throws -> String {
        let output = await session.infer(message: message)
        guard let output = await callTool(output) else {
            return output
        }
        return output
    }
}
public protocol LlamaActor: Actor {
    static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] { get }
    var session: LlamaToolSession { get }
    static func tools(_ self: Self) -> [String: (DynamicCallable, _JSONFunctionSchema)]
    var session: LlamaToolSession! { get }
}

public extension LlamaActor {
    func chat(_ message: String) async throws -> String {
        try await session.chat(message: message)
        try await session.infer(message: message)
    }
}
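Note: with the macro below, a conforming actor is used like this (see the `MyLlama` declaration and the WeatherKit test later in this diff):

    let llama = try await MyLlama(params: params)             // init synthesized by @llamaActor
    let reply = try await llama.chat("What's today's date?")  // may route through @Tool calls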
@@ -16,12 +16,14 @@ struct ToolMacro: BodyMacro {

struct LlamaActorMacro: ExtensionMacro, MemberMacro {
    static func expansion(of node: AttributeSyntax, providingMembersOf declaration: some DeclGroupSyntax, conformingTo protocols: [TypeSyntax], in context: some MacroExpansionContext) throws -> [DeclSyntax] {
        [
        return [
            """
            let session: LlamaToolSession
            var session: LlamaToolSession!

            public init(params: GPTParams) async throws {
                self.session = try await LlamaToolSession(params: params, tools: Self.tools)
                \(raw: declaration.inheritanceClause != nil ? "self.init()" : "")
                let tools = Self.tools(self)
                self.session = try await LlamaToolSession(params: params, tools: tools)
            }
            """
        ]

@@ -41,6 +43,7 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
             callableString: String,
             callableName: String)
        ] = []
        let typeName = type.as(IdentifierTypeSyntax.self)!.name.text
        for member in declaration.memberBlock.members {
            let comments = member.leadingTrivia.filter { $0.isComment }
            guard let member = member.decl.as(FunctionDeclSyntax.self) else {

@@ -72,6 +75,11 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
            let callableName = context.makeUniqueName(name.text)
            let callableString = """
            @dynamicCallable struct \(callableName.text): DynamicCallable {
                private weak var llamaActor: \(typeName)?
                init(_ llamaActor: \(typeName)) {
                    self.llamaActor = llamaActor
                }

                @discardableResult
                func dynamicallyCall(withKeywordArguments args: [String: Any]) async throws -> String {
                    \(parameters.map {

@@ -79,11 +87,15 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
                    }.joined(separator: "\n"))
                    for (key, value) in args {
                        \(parameters.map {
                            "if key == \"\($0.name)\" { \($0.name) = value as! \($0.type) }"
                            """
                            if key == "\($0.name)", let value = value as? AnyDecodable, case let .\($0.type.lowercased())(v) = value {
                                \($0.name) = v
                            }
                            """
                        }.joined(separator: "\n"))
                    }

                    let returnValue = try await \(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ",")))
                    let returnValue = try await self.llamaActor!.\(name.text)(\(parameters.map { "\($0.name): \($0.name)" }.joined(separator: ",")))
                    let jsonValue = try JSONEncoder().encode(returnValue)
                    return String(data: jsonValue, encoding: .utf8)!
                }

@@ -105,10 +117,10 @@ struct LlamaActorMacro: ExtensionMacro, MemberMacro {
                $0.callableString
            }.joined(separator: "\n"))

            static var tools: [String: (DynamicCallable, _JSONFunctionSchema)] {
            static func tools(_ self: \(raw: typeName)) -> [String: (DynamicCallable, _JSONFunctionSchema)] {
                [\(raw: tools.map { tool in
                    """
                    "\(tool.name)": (\(tool.callableName)(), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in
                    "\(tool.name)": (\(tool.callableName)(self), _JSONFunctionSchema(name: "\(tool.name)", description: "\(tool.description)", parameters: _JSONFunctionSchema.Parameters(properties: \(tool.parameters.count == 0 ? "[:]" : "[" + tool.parameters.map { parameter in
                    """
                    "\(parameter.name)": _JSONFunctionSchema.Property(type: \(parameter.type).self, description: "\(parameter.description)"),
                    """
@@ -32,7 +32,7 @@ func downloadFile() async throws -> String {

@llamaActor actor MyLlama {
    /// Get the current date.
    @Tool public static func getCurrentDate() -> String {
    @Tool public func getCurrentDate() -> String {
        Date.now.formatted(date: .long, time: .complete)
    }
}
370  swift/test/GPTParamsTests.swift  Normal file

@@ -0,0 +1,370 @@
import Foundation
import XCTest

import LlamaKit

final class GPTParamsTests: XCTestCase {
    func testPropertyAssignmentsAndCopy() throws {
        // Create an instance of GPTParams
        let originalParams = GPTParams()

        // Assign values to all properties
        originalParams.nPredict = 10
        originalParams.nCtx = 20
        originalParams.nBatch = 30
        originalParams.nUBatch = 40
        originalParams.nKeep = 50
        originalParams.nDraft = 60
        originalParams.nChunks = 70
        originalParams.nParallel = 80
        originalParams.nSequences = 90
        originalParams.pSplit = 0.5
        originalParams.nGpuLayers = 100
        originalParams.nGpuLayersDraft = 110
        originalParams.mainGpu = 120
        originalParams.tensorSplit = [0.1, 0.2, 0.3]
        originalParams.grpAttnN = 130
        originalParams.grpAttnW = 140
        originalParams.nPrint = 150
        originalParams.ropeFreqBase = 0.6
        originalParams.ropeFreqScale = 0.7
        originalParams.yarnExtFactor = 0.8
        originalParams.yarnAttnFactor = 0.9
        originalParams.yarnBetaFast = 1.0
        originalParams.yarnBetaSlow = 1.1
        originalParams.yarnOrigCtx = 160
        originalParams.defragThold = 1.2

        // Initialize CPUParams instances if needed
        originalParams.cpuParams = CPUParams()
        originalParams.cpuParamsBatch = CPUParams()
        originalParams.draftCpuParams = CPUParams()
        originalParams.draftCpuParamsBatch = CPUParams()

        // Assign blocks and user data
        originalParams.cbEval = { userData in
            // Callback implementation
        }
        originalParams.cbEvalUserData = nil

        // Assign enum values (assuming NSInteger maps to Int)
        originalParams.numaStrategy = 1
        originalParams.splitMode = 2
        originalParams.ropeScalingType = 3
        originalParams.poolingType = 4
        originalParams.attentionType = 5

        // Assign sampler parameters
        originalParams.samplerParams = GPTSamplerParams()

        // Assign string properties
        originalParams.modelPath = "path/to/model"
        originalParams.modelDraft = "model_draft"
        originalParams.modelAlias = "alias"
        originalParams.modelURL = "http://model.url"
        originalParams.hfToken = "token"
        originalParams.hfRepo = "repo"
        originalParams.hfFile = "file"
        originalParams.prompt = "prompt"
        originalParams.promptFile = "prompt.txt"
        originalParams.pathPromptCache = "cache/path"
        originalParams.inputPrefix = "prefix"
        originalParams.inputSuffix = "suffix"
        originalParams.logdir = "log/dir"
        originalParams.lookupCacheStatic = "static/cache"
        originalParams.lookupCacheDynamic = "dynamic/cache"
        originalParams.logitsFile = "logits.txt"
        originalParams.rpcServers = "servers"

        // Assign array properties
        originalParams.inputFiles = ["input1.txt", "input2.txt"]
        originalParams.antiPrompts = ["anti1", "anti2"]
        originalParams.kvOverrides = ["override1", "override2"]
        originalParams.loraAdapters = ["adapter1", "adapter2"]
        originalParams.controlVectors = ["control1", "control2"]

        // Assign boolean and control properties
        originalParams.loraInitWithoutApply = true
        originalParams.verbosity = 1
        originalParams.controlVectorLayerStart = 2
        originalParams.controlVectorLayerEnd = 3
        originalParams.pplStride = 4
        originalParams.pplOutputType = 5
        originalParams.hellaswag = true
        originalParams.hellaswagTasks = 10
        originalParams.winogrande = false
        originalParams.winograndeTasks = 20
        originalParams.multipleChoice = true
        originalParams.multipleChoiceTasks = 30
        originalParams.klDivergence = false
        originalParams.usage = true
        originalParams.useColor = false
        originalParams.special = true
        originalParams.interactive = false
        originalParams.interactiveFirst = true
        originalParams.conversation = false
        originalParams.promptCacheAll = true
        originalParams.promptCacheRO = false
        originalParams.escapeSequences = true
        originalParams.multilineInput = false
        originalParams.simpleIO = true
        originalParams.continuousBatching = false
        originalParams.flashAttention = true
        originalParams.noPerformanceMetrics = false
        originalParams.contextShift = true

        // Server and I/O settings
        originalParams.port = 8080
        originalParams.timeoutRead = 60
        originalParams.timeoutWrite = 30
        originalParams.httpThreads = 4
        originalParams.hostname = "localhost"
        originalParams.publicPath = "/public"
        originalParams.chatTemplate = "template"
        originalParams.systemPrompt = "system prompt"
        originalParams.enableChatTemplate = true
        originalParams.apiKeys = ["key1", "key2"]
        originalParams.sslFileKey = "key.pem"
        originalParams.sslFileCert = "cert.pem"
        originalParams.endpointSlots = true
        originalParams.endpointMetrics = false
        originalParams.logJSON = true
        originalParams.slotSavePath = "/slots"
        originalParams.slotPromptSimilarity = 0.75

        // Batched-bench params
        originalParams.isPPShared = true
        originalParams.nPP = [1, 2]
        originalParams.nTG = [3, 4]
        originalParams.nPL = [5, 6]

        // Retrieval params
        originalParams.contextFiles = ["context1.txt", "context2.txt"]
        originalParams.chunkSize = 1024
        originalParams.chunkSeparator = "\n"

        // Passkey params
        originalParams.nJunk = 7
        originalParams.iPos = 8

        // Imatrix params
        originalParams.outFile = "output.txt"
        originalParams.nOutFreq = 100
        originalParams.nSaveFreq = 200
        originalParams.iChunk = 9
        originalParams.processOutput = true
        originalParams.computePPL = false

        // Cvector-generator params
        originalParams.nPCABatch = 10
        originalParams.nPCAIterations = 11
        originalParams.cvectorDimreMethod = 12
        originalParams.cvectorOutfile = "cvector.out"
        originalParams.cvectorPositiveFile = "positive.txt"
        originalParams.cvectorNegativeFile = "negative.txt"

        // Additional properties
        originalParams.spmInfill = true
        originalParams.loraOutfile = "lora.out"
        originalParams.embedding = false
        originalParams.verbosePrompt = true
        originalParams.batchedBenchOutputJSONL = false
        originalParams.inputPrefixBOS = true
        originalParams.ctxShift = false
        originalParams.displayPrompt = true
        originalParams.logging = false
        // Verify that properties are assigned correctly
        XCTAssertEqual(originalParams.nPredict, 10)
        XCTAssertEqual(originalParams.nCtx, 20)
        XCTAssertEqual(originalParams.nBatch, 30)
        XCTAssertEqual(originalParams.nUBatch, 40)
        XCTAssertEqual(originalParams.nKeep, 50)
        XCTAssertEqual(originalParams.nDraft, 60)
        XCTAssertEqual(originalParams.nChunks, 70)
        XCTAssertEqual(originalParams.nParallel, 80)
        XCTAssertEqual(originalParams.nSequences, 90)
        XCTAssertEqual(originalParams.pSplit, 0.5)
        XCTAssertEqual(originalParams.nGpuLayers, 100)
        XCTAssertEqual(originalParams.nGpuLayersDraft, 110)
        XCTAssertEqual(originalParams.mainGpu, 120)
        XCTAssertEqual(originalParams.tensorSplit[0..<3].map(\.floatValue),
                       [0.1, 0.2, 0.3])
        XCTAssertEqual(originalParams.grpAttnN, 130)
        XCTAssertEqual(originalParams.grpAttnW, 140)
        XCTAssertEqual(originalParams.nPrint, 150)
        XCTAssertEqual(originalParams.ropeFreqBase, 0.6)
        XCTAssertEqual(originalParams.ropeFreqScale, 0.7)
        XCTAssertEqual(originalParams.yarnExtFactor, 0.8)
        XCTAssertEqual(originalParams.yarnAttnFactor, 0.9)
        XCTAssertEqual(originalParams.yarnBetaFast, 1.0)
        XCTAssertEqual(originalParams.yarnBetaSlow, 1.1)
        XCTAssertEqual(originalParams.yarnOrigCtx, 160)
        XCTAssertEqual(originalParams.defragThold, 1.2)

        // Verify enums
        XCTAssertEqual(originalParams.numaStrategy, 1)
        XCTAssertEqual(originalParams.splitMode, 2)
        XCTAssertEqual(originalParams.ropeScalingType, 3)
        XCTAssertEqual(originalParams.poolingType, 4)
        XCTAssertEqual(originalParams.attentionType, 5)

        // Verify string properties
        XCTAssertEqual(originalParams.modelPath, "path/to/model")
        XCTAssertEqual(originalParams.modelDraft, "model_draft")
        XCTAssertEqual(originalParams.modelAlias, "alias")
        XCTAssertEqual(originalParams.modelURL, "http://model.url")
        XCTAssertEqual(originalParams.hfToken, "token")
        XCTAssertEqual(originalParams.hfRepo, "repo")
        XCTAssertEqual(originalParams.hfFile, "file")
        XCTAssertEqual(originalParams.prompt, "prompt")
        XCTAssertEqual(originalParams.promptFile, "prompt.txt")
        XCTAssertEqual(originalParams.pathPromptCache, "cache/path")
        XCTAssertEqual(originalParams.inputPrefix, "prefix")
        XCTAssertEqual(originalParams.inputSuffix, "suffix")
        XCTAssertEqual(originalParams.logdir, "log/dir")
        XCTAssertEqual(originalParams.lookupCacheStatic, "static/cache")
        XCTAssertEqual(originalParams.lookupCacheDynamic, "dynamic/cache")
        XCTAssertEqual(originalParams.logitsFile, "logits.txt")
        XCTAssertEqual(originalParams.rpcServers, "servers")

        // Verify array properties
        XCTAssertEqual(originalParams.inputFiles, ["input1.txt", "input2.txt"])
        XCTAssertEqual(originalParams.antiPrompts, ["anti1", "anti2"])
        XCTAssertEqual(originalParams.kvOverrides as? [String], ["override1", "override2"])
        // XCTAssertEqual(originalParams.loraAdapters, ["adapter1", "adapter2"])
        // XCTAssertEqual(originalParams.controlVectors, ["control1", "control2"])

        // Verify boolean and control properties
        XCTAssertTrue(originalParams.loraInitWithoutApply)
        XCTAssertEqual(originalParams.verbosity, 1)
        XCTAssertEqual(originalParams.controlVectorLayerStart, 2)
        XCTAssertEqual(originalParams.controlVectorLayerEnd, 3)
        XCTAssertEqual(originalParams.pplStride, 4)
        XCTAssertEqual(originalParams.pplOutputType, 5)
        XCTAssertTrue(originalParams.hellaswag)
        XCTAssertEqual(originalParams.hellaswagTasks, 10)
        XCTAssertFalse(originalParams.winogrande)
        XCTAssertEqual(originalParams.winograndeTasks, 20)
        XCTAssertTrue(originalParams.multipleChoice)
        XCTAssertEqual(originalParams.multipleChoiceTasks, 30)
        XCTAssertFalse(originalParams.klDivergence)
        XCTAssertTrue(originalParams.usage)
        XCTAssertFalse(originalParams.useColor)
        XCTAssertTrue(originalParams.special)
        XCTAssertFalse(originalParams.interactive)
        XCTAssertTrue(originalParams.interactiveFirst)
        XCTAssertFalse(originalParams.conversation)
        XCTAssertTrue(originalParams.promptCacheAll)
        XCTAssertFalse(originalParams.promptCacheRO)
        XCTAssertTrue(originalParams.escapeSequences)
        XCTAssertFalse(originalParams.multilineInput)
        XCTAssertTrue(originalParams.simpleIO)
        XCTAssertFalse(originalParams.continuousBatching)
        XCTAssertTrue(originalParams.flashAttention)
        XCTAssertFalse(originalParams.noPerformanceMetrics)
        XCTAssertTrue(originalParams.contextShift)

        // Verify server and I/O settings
        XCTAssertEqual(originalParams.port, 8080)
        XCTAssertEqual(originalParams.timeoutRead, 60)
        XCTAssertEqual(originalParams.timeoutWrite, 30)
        XCTAssertEqual(originalParams.httpThreads, 4)
        XCTAssertEqual(originalParams.hostname, "localhost")
        XCTAssertEqual(originalParams.publicPath, "/public")
        XCTAssertEqual(originalParams.chatTemplate, "template")
        XCTAssertEqual(originalParams.systemPrompt, "system prompt")
        XCTAssertTrue(originalParams.enableChatTemplate)
        XCTAssertEqual(originalParams.apiKeys, ["key1", "key2"])
        XCTAssertEqual(originalParams.sslFileKey, "key.pem")
        XCTAssertEqual(originalParams.sslFileCert, "cert.pem")
        XCTAssertTrue(originalParams.endpointSlots)
        XCTAssertFalse(originalParams.endpointMetrics)
        XCTAssertTrue(originalParams.logJSON)
        XCTAssertEqual(originalParams.slotSavePath, "/slots")
        XCTAssertEqual(originalParams.slotPromptSimilarity, 0.75)

        // Verify batched-bench params
        XCTAssertTrue(originalParams.isPPShared)
        XCTAssertEqual(originalParams.nPP, [1, 2])
        XCTAssertEqual(originalParams.nTG, [3, 4])
        XCTAssertEqual(originalParams.nPL, [5, 6])

        // Verify retrieval params
        XCTAssertEqual(originalParams.contextFiles, ["context1.txt", "context2.txt"])
        XCTAssertEqual(originalParams.chunkSize, 1024)
        XCTAssertEqual(originalParams.chunkSeparator, "\n")

        // Verify passkey params
        XCTAssertEqual(originalParams.nJunk, 7)
        XCTAssertEqual(originalParams.iPos, 8)

        // Verify imatrix params
        XCTAssertEqual(originalParams.outFile, "output.txt")
        XCTAssertEqual(originalParams.nOutFreq, 100)
        XCTAssertEqual(originalParams.nSaveFreq, 200)
        XCTAssertEqual(originalParams.iChunk, 9)
        XCTAssertTrue(originalParams.processOutput)
        XCTAssertFalse(originalParams.computePPL)

        // Verify cvector-generator params
        XCTAssertEqual(originalParams.nPCABatch, 10)
        XCTAssertEqual(originalParams.nPCAIterations, 11)
        XCTAssertEqual(originalParams.cvectorDimreMethod, 12)
        XCTAssertEqual(originalParams.cvectorOutfile, "cvector.out")
        XCTAssertEqual(originalParams.cvectorPositiveFile, "positive.txt")
        XCTAssertEqual(originalParams.cvectorNegativeFile, "negative.txt")

        // Verify additional properties
        XCTAssertTrue(originalParams.spmInfill)
        XCTAssertEqual(originalParams.loraOutfile, "lora.out")
        XCTAssertFalse(originalParams.embedding)
        XCTAssertTrue(originalParams.verbosePrompt)
        XCTAssertFalse(originalParams.batchedBenchOutputJSONL)
        XCTAssertTrue(originalParams.inputPrefixBOS)
        XCTAssertFalse(originalParams.ctxShift)
        XCTAssertTrue(originalParams.displayPrompt)
        XCTAssertFalse(originalParams.logging)
        // Test the copy function
        guard let copiedParams = originalParams.copy() as? GPTParams else {
            XCTFail("Copy function did not return a GPTParams instance.")
            return
        }

        // Verify that the copied properties match the original
        XCTAssertEqual(copiedParams.nPredict, originalParams.nPredict)
        XCTAssertEqual(copiedParams.nCtx, originalParams.nCtx)
        XCTAssertEqual(copiedParams.nBatch, originalParams.nBatch)
        XCTAssertEqual(copiedParams.nUBatch, originalParams.nUBatch)
        XCTAssertEqual(copiedParams.nKeep, originalParams.nKeep)
        XCTAssertEqual(copiedParams.nDraft, originalParams.nDraft)
        XCTAssertEqual(copiedParams.nChunks, originalParams.nChunks)
        XCTAssertEqual(copiedParams.nParallel, originalParams.nParallel)
        XCTAssertEqual(copiedParams.nSequences, originalParams.nSequences)
        XCTAssertEqual(copiedParams.pSplit, originalParams.pSplit)
        XCTAssertEqual(copiedParams.nGpuLayers, originalParams.nGpuLayers)
        XCTAssertEqual(copiedParams.nGpuLayersDraft, originalParams.nGpuLayersDraft)
        XCTAssertEqual(copiedParams.mainGpu, originalParams.mainGpu)
        XCTAssertEqual(copiedParams.tensorSplit, originalParams.tensorSplit)
        XCTAssertEqual(copiedParams.grpAttnN, originalParams.grpAttnN)
        XCTAssertEqual(copiedParams.grpAttnW, originalParams.grpAttnW)
        XCTAssertEqual(copiedParams.nPrint, originalParams.nPrint)
        XCTAssertEqual(copiedParams.ropeFreqBase, originalParams.ropeFreqBase)
        XCTAssertEqual(copiedParams.ropeFreqScale, originalParams.ropeFreqScale)
        XCTAssertEqual(copiedParams.yarnExtFactor, originalParams.yarnExtFactor)
        XCTAssertEqual(copiedParams.yarnAttnFactor, originalParams.yarnAttnFactor)
        XCTAssertEqual(copiedParams.yarnBetaFast, originalParams.yarnBetaFast)
        XCTAssertEqual(copiedParams.yarnBetaSlow, originalParams.yarnBetaSlow)
        XCTAssertEqual(copiedParams.yarnOrigCtx, originalParams.yarnOrigCtx)
        XCTAssertEqual(copiedParams.defragThold, originalParams.defragThold)
        XCTAssertEqual(copiedParams.modelPath, originalParams.modelPath)
        XCTAssertEqual(copiedParams.apiKeys, originalParams.apiKeys)
        // XCTAssertEqual(copiedParams.controlVectors, originalParams.controlVectors)
        // Continue verifying all other properties...

        // Verify that modifying the original does not affect the copy
        originalParams.nPredict = 999
        XCTAssertNotEqual(copiedParams.nPredict, originalParams.nPredict)
    }
}
@@ -2,27 +2,28 @@ import Foundation
import Testing
@testable import LlamaKit
import JSONSchema
import OSLog

// MARK: LlamaGrammarSession Suite
@Suite("LlamaGrammarSession Suite")
struct LlamaGrammarSessionSuite {
@Suite("LlamaSession Suite")
struct LlamaSessionSuite {
    @JSONSchema struct Trip {
        let location: String
        let startDate: TimeInterval
        let durationInDays: Int
    }

    func downloadFile() async throws -> String {
    func downloadFile(url: String, to path: String) async throws -> String {
        let fm = FileManager.default
        let tmpDir = fm.temporaryDirectory
        let destinationURL = tmpDir.appending(path: "tinyllama.gguf")
        let destinationURL = tmpDir.appending(path: path)

        guard !fm.fileExists(atPath: destinationURL.path()) else {
            return destinationURL.path()
        }
        print("Downloading TinyLlama, this may take a while...")
        print("Downloading \(path), this may take a while...")
        // Define the URL
        guard let url = URL(string: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf?download=true") else {
        guard let url = URL(string: url) else {
            print("Invalid URL.")
            throw URLError(.badURL)
        }
@@ -32,22 +33,22 @@ struct LlamaGrammarSessionSuite {

        // Define the destination path in the documents directory

        // Move the downloaded file to the destination
        try fm.moveItem(at: tempURL, to: destinationURL)
        print("File downloaded to: \(destinationURL.path())")
        return destinationURL.path()
    }

    @Test func llamaGrammarSession() async throws {
    func baseParams(url: String, to path: String) async throws -> GPTParams {
        let params = GPTParams()
        params.modelPath = try await downloadFile()
        params.nPredict = 256
        params.nCtx = 1024
        params.cpuParams.nThreads = 4
        params.cpuParamsBatch.nThreads = 4
        params.modelPath = try await downloadFile(url: url, to: path)
        params.nPredict = 512
        params.nCtx = 4096
        params.cpuParams.nThreads = 8
        params.cpuParamsBatch.nThreads = 8
        params.nBatch = 1024
        params.nGpuLayers = 128
        params.nGpuLayers = 1024
        params.chatTemplate = """
        <|system|>
        {system_message}</s>

@@ -55,6 +56,28 @@ struct LlamaGrammarSessionSuite {
        {prompt}</s>
        <|assistant|>
        """
        params.interactive = true
        return params
    }

    @Test func llamaInferenceSession() async throws {
        let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf")
        params.prompt = """
        <|system|>
        You are an AI assistant. Answer queries simply and concisely.</s>
        """
        params.antiPrompts = ["</s>"]
        params.inputPrefix = "<|user|>"
        params.inputSuffix = "</s><|assistant|>"
        params.interactive = true
        let session = try await LlamaChatSession(params: params, flush: false)
        for await msg in await session.inferenceStream(message: "How are you today?") {
            print(msg, terminator: "")
        }
    }

    @Test func llamaGrammarSession() async throws {
        let params = try await baseParams(url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true", to: "tinyllama.gguf")
        params.prompt = """
        You are a travel agent. The current date epoch \(Date.now.timeIntervalSince1970).
        Responses should have the following fields:

@@ -64,7 +87,6 @@ struct LlamaGrammarSessionSuite {
        durationInDays: the duration of the trip in days

        """
        params.interactive = true
        let session = try await LlamaSession<Trip>(params: params)
        await #expect(throws: Never.self) {
            let trip = try await session.chat(message: "Please create a trip for me to New York City that starts two weeks from now. The duration of the trip MUST be 3 days long.")

@@ -73,6 +95,25 @@ struct LlamaGrammarSessionSuite {
        // TODO: so for now, we are just asserting the grammar works
        }
    }
    @JSONSchema struct IsCorrect {
        let isSpellingCorrect: Bool
    }

    @Test func llamaSimpleGrammarSession() async throws {
        let params = try await baseParams(url: "https://huggingface.co/RichardErkhov/openfoodfacts_-_spellcheck-mistral-7b-gguf/resolve/main/spellcheck-mistral-7b.Q8_0.gguf?download=true",
                                          to: "spellcheck_q8.gguf")
        params.prompt = """
        ###You are a spell checker. I will provide you with the word 'strawberry'. If the spelling of the given word is correct, please respond {"isSpellingCorrect": true} else respond {"isSpellingCorrect": false}.\n
        """
        let session = try await LlamaSession<IsCorrect>(params: params)
        for _ in 0..<10 {
            var output = try await session.chat(message: "###strawberry\n")
            #expect(output.isSpellingCorrect)
            output = try await session.chat(message: "###strawberrry\n")
            #expect(!output.isSpellingCorrect)
        }
    }
}

import WeatherKit

@@ -115,7 +156,7 @@ func downloadFile() async throws -> String {
    /// Get the current weather in a given location.
    /// - parameter location: The city and state, e.g. San Francisco, CA
    /// - parameter unit: The unit of temperature
    @Tool public static func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather {
    @Tool public func getCurrentWeather(location: String, unit: String) async throws -> CurrentWeather {
        let weather = try await WeatherService().weather(for: CLGeocoder().geocodeAddressString(location)[0].location!)
        var temperature = weather.currentWeather.temperature
        temperature.convert(to: .fahrenheit)

@@ -134,7 +175,7 @@ func downloadFile() async throws -> String {
        params.nBatch = 1024
        params.nGpuLayers = 1024
        let llama = try await MyLlama(params: params)
        let currentWeather = try await MyLlama.getCurrentWeather(location: "San Francisco, CA", unit: "fahrenheit")
        let currentWeather = try await llama.getCurrentWeather(location: "San Francisco, CA", unit: "fahrenheit")
        let output = try await llama.chat("What's the weather (in fahrenheit) in San Francisco, CA?")
        #expect(output.contains(String(format: "%d", Int(currentWeather.temperature))))
        // #expect(output.contains(currentWeather.condition.rawValue))