cleanup and stuff

This commit is contained in:
Henri Vasserman 2023-05-16 15:16:00 +03:00
parent 021e6d9944
commit 8388aaa604
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986
4 changed files with 48 additions and 37 deletions

View file

@ -136,28 +136,6 @@ int main(int argc, char ** argv) {
return 0;
}
// Steering setup: runs when at least one of the two steering strings is non-empty.
if (params.steering_add.size() || params.steering_sub.size())
{
// Tokenize both steering strings (the trailing `true` is presumably the add-BOS flag -- confirm against llama_tokenize).
auto steering_add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
auto steering_sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
// Pad the shorter token list with the token for " " until both lists have the same length.
if (steering_add_tokens.size() != steering_sub_tokens.size()) {
llama_token space;
// NOTE(review): the return value is ignored here; if no token gets written, `space` stays uninitialized -- confirm.
llama_tokenize(ctx, " ", &space, 1, 0);
while (steering_add_tokens.size() < steering_sub_tokens.size()) steering_add_tokens.push_back(space);
while (steering_sub_tokens.size() < steering_add_tokens.size()) steering_sub_tokens.push_back(space);
}
// Pass 1: steering write mode on layer steering_lyr with factor +steering_mul/2, then evaluate the
// "add" tokens. The token count is clamped to params.n_ctx; the 0 is presumably n_past -- confirm.
llama_set_steering_write(ctx, params.steering_lyr, params.steering_mul/2);
llama_eval(ctx, steering_add_tokens.data(), std::min((int)steering_add_tokens.size(), params.n_ctx), 0, params.n_threads);
// Pass 2: same layer with the negated factor, evaluating the "sub" tokens.
llama_set_steering_write(ctx, params.steering_lyr, -params.steering_mul/2);
llama_eval(ctx, steering_sub_tokens.data(), std::min((int)steering_sub_tokens.size(), params.n_ctx), 0, params.n_threads);
// Switch the same layer to steering read mode with a factor of 1
// (presumably applied during later evaluations -- confirm against llama_set_steering_read).
llama_set_steering_read(ctx, params.steering_lyr, 1);
}
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
@ -196,6 +174,32 @@ int main(int argc, char ** argv) {
return 1;
}
// Steering setup: runs when at least one of the two steering strings is non-empty.
if (!params.steering_add.empty() || !params.steering_sub.empty())
{
// Prepend a single space to BOTH strings (modifies params in place). Because the guard above is an OR,
// a string that was empty becomes " " and is still tokenized and evaluated below.
params.steering_add.insert(0, 1, ' ');
params.steering_sub.insert(0, 1, ' ');
// Tokenize both strings (the trailing `true` is presumably the add-BOS flag -- confirm against llama_tokenize).
auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
// Disabled: padding the shorter token list with newline tokens so both lists have equal length.
// With this commented out, the two lists may differ in length.
//if (add_tokens.size() != sub_tokens.size()) {
// while (add_tokens.size() < sub_tokens.size()) {
// add_tokens.push_back(llama_token_nl());
// }
// while (sub_tokens.size() < add_tokens.size()) {
// sub_tokens.push_back(llama_token_nl());
// }
//}
//const int N = embd_inp.size();
// Pass 1: steering write mode on steering_layer with factor +1, then evaluate the "add" tokens.
// The token count is clamped to n_ctx; the 0 is presumably n_past -- confirm.
llama_set_steering_write(ctx, params.steering_layer, +1.0f);
llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads);
// Pass 2: same layer with factor -1, evaluating the "sub" tokens.
llama_set_steering_write(ctx, params.steering_layer, -1.0f);
llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads);
// Switch the layer to steering read mode; the user-supplied steering_mul is passed here
// rather than to the write calls above.
llama_set_steering_read(ctx, params.steering_layer, params.steering_mul);
}
// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {