add arg flag, not working on embedding mode

This commit is contained in:
strikingLoo 2023-03-18 23:34:20 -07:00
parent b97df76c54
commit 801071ec4f
3 changed files with 75 additions and 51 deletions

122
main.cpp
View file

@ -519,6 +519,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
return true;
}
// Prints the provided embedding vector to stdout
// in a neat format
void display_embedding(const std::vector<float> & embedding_representation){
fprintf(stdout, "\n[\n");
for (int j = 0; j < embedding_representation.size()-1 ; j++){
fprintf(stdout, "%f, ", embedding_representation[j]);
}
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
fprintf(stdout, "\n]\n");
}
// evaluate the transformer
//
// - model: the model
@ -535,7 +546,8 @@ bool llama_eval(
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
size_t & mem_per_token,
const bool embeding_mode) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
@ -720,56 +732,52 @@ bool llama_eval(
ggml_repeat(ctx0, model.norm, inpL),
inpL);
}
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
// capture input sentence embedding
{
std::vector<float> embedding_representation;
embedding_representation.resize(n_embd);
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 2)), sizeof(float) * n_embd);
fprintf(stdout, "\n[\n");
for (int j = 0; j < embedding_representation.size()-1 ; j++){
fprintf(stdout, "%f, ", embedding_representation[j]);
if(!embeding_mode){
// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
}
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
fprintf(stdout, "\n]\n");
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
} else {
// capture input sentence embedding
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
printf("Compute went ok\n");
std::vector<float> embedding_representation;
embedding_representation.resize(n_embd);
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
printf("About to display\n");
display_embedding(embedding_representation);
printf("About to free\n");
ggml_free(ctx0);
return true;
}
// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
}
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
static bool is_interacting = false;
@ -906,13 +914,12 @@ int main(int argc, char ** argv) {
// determine the required inference memory per token:
size_t mem_per_token = 0;
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, false);
int last_n_size = params.repeat_last_n;
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@ -936,12 +943,27 @@ int main(int argc, char ** argv) {
printf(ANSI_COLOR_YELLOW);
}
if (params.embedding){
printf("got right before second call.\n");
const int64_t t_start_us = ggml_time_us(); //HERE
if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
fprintf(stderr, "Failed to predict\n");
return 1;
}
//ggml_free(model.ctx);
if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
return 0;
}
while (remaining_tokens > 0) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, false)) {
fprintf(stderr, "Failed to predict\n");
return 1;
}

View file

@ -53,6 +53,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
params.embedding = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
params.interactive_start = true;

View file

@ -31,7 +31,7 @@ struct gpt_params {
std::string prompt;
bool use_color = false; // use color to distinguish generations and inputs
bool embedding = false; // get only sentence embedding
bool interactive = false; // interactive mode
bool interactive_start = false; // reverse prompt immediately
std::string antiprompt = ""; // string upon seeing which more user input is prompted