Steering
This commit is contained in:
parent
63d20469b8
commit
021e6d9944
5 changed files with 100 additions and 0 deletions
|
@ -344,6 +344,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.input_suffix = argv[i];
|
||||
} else if (arg == "--steering-add") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.steering_add = argv[i];
|
||||
} else if (arg == "--steering-sub") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.steering_sub = argv[i];
|
||||
} else if (arg == "--steering-mul") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.steering_mul = std::stof(argv[i]);
|
||||
} else if (arg == "--steering-lyr") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.steering_lyr = std::stoi(argv[i]);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
gpt_print_usage(argc, argv, default_params);
|
||||
|
|
|
@ -72,6 +72,11 @@ struct gpt_params {
|
|||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool mem_test = false; // compute maximum memory usage
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
|
||||
std::string steering_add = "";
|
||||
std::string steering_sub = "";
|
||||
float steering_mul = 1.0f;
|
||||
int steering_lyr = 20;
|
||||
};
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||
|
|
|
@ -136,6 +136,28 @@ int main(int argc, char ** argv) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (params.steering_add.size() || params.steering_sub.size())
|
||||
{
|
||||
auto steering_add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
|
||||
auto steering_sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
|
||||
|
||||
if (steering_add_tokens.size() != steering_sub_tokens.size()) {
|
||||
llama_token space;
|
||||
llama_tokenize(ctx, " ", &space, 1, 0);
|
||||
|
||||
while (steering_add_tokens.size() < steering_sub_tokens.size()) steering_add_tokens.push_back(space);
|
||||
while (steering_sub_tokens.size() < steering_add_tokens.size()) steering_sub_tokens.push_back(space);
|
||||
}
|
||||
|
||||
llama_set_steering_write(ctx, params.steering_lyr, params.steering_mul/2);
|
||||
llama_eval(ctx, steering_add_tokens.data(), std::min((int)steering_add_tokens.size(), params.n_ctx), 0, params.n_threads);
|
||||
|
||||
llama_set_steering_write(ctx, params.steering_lyr, -params.steering_mul/2);
|
||||
llama_eval(ctx, steering_sub_tokens.data(), std::min((int)steering_sub_tokens.size(), params.n_ctx), 0, params.n_threads);
|
||||
|
||||
llama_set_steering_read(ctx, params.steering_lyr, 1);
|
||||
}
|
||||
|
||||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||
params.prompt.insert(0, 1, ' ');
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue