add arg flag, not working on embedding mode
This commit is contained in: parent b97df76c54, commit 801071ec4f
3 changed files with 75 additions and 51 deletions
main.cpp (122 changes)
@@ -519,6 +519,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     return true;
 }
 
+
+// Prints the provided embedding vector to stdout
+// in a neat format
+void display_embedding(const std::vector<float> & embedding_representation){
+    fprintf(stdout, "\n[\n");
+    for (int j = 0; j < embedding_representation.size()-1 ; j++){
+        fprintf(stdout, "%f, ", embedding_representation[j]);
+    }
+    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
+    fprintf(stdout, "\n]\n");
+}
 
 // evaluate the transformer
 //
 // - model: the model
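The helper added above can be exercised on its own. The snippet below is a hypothetical standalone check, not part of the commit: it reuses display_embedding as written in the hunk and feeds it a small hand-made vector. It assumes the vector is non-empty, since the helper indexes size()-1 without a guard.

// Hypothetical standalone check of the display_embedding helper added above.
// Not part of the commit. Assumes a non-empty vector: the helper reads
// embedding_representation[embedding_representation.size()-1] unconditionally.
#include <cstdio>
#include <vector>

void display_embedding(const std::vector<float> & embedding_representation){
    fprintf(stdout, "\n[\n");
    for (int j = 0; j < embedding_representation.size()-1 ; j++){
        fprintf(stdout, "%f, ", embedding_representation[j]);
    }
    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
    fprintf(stdout, "\n]\n");
}

int main() {
    // Stand-in for a real n_embd-sized embedding vector.
    std::vector<float> embedding = { 0.1f, -0.25f, 0.5f };
    display_embedding(embedding); // prints the three values on one line between "[" and "]"
    return 0;
}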
@@ -535,7 +546,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
         std::vector<float> & embd_w,
-        size_t & mem_per_token) {
+        size_t & mem_per_token,
+        const bool embeding_mode) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
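Adding the trailing flag changes the arity of llama_eval, so every existing call site has to be updated as well: the call sites updated further down pass false, and the new embedding branch in main() passes true. The fragment below is only an illustration with a hypothetical cut-down signature, not the real function; it shows the two call shapes and that a defaulted flag would have been one way to leave old call sites untouched.

// Illustration only: a cut-down stand-in for the extended llama_eval signature.
// The real function also takes the model, thread count, past-token count and
// other arguments; those are omitted here.
#include <cstdio>
#include <vector>

static bool llama_eval_stub(const std::vector<int> & embd_inp,
                            std::vector<float> & embd_w,
                            const bool embeding_mode = false) { // defaulted: old callers keep compiling
    if (embeding_mode) {
        printf("embedding mode: print the sentence embedding and stop\n");
        return true;
    }
    embd_w.assign(4, 0.0f); // pretend logits for the last token
    return true;
}

int main() {
    std::vector<int>   tokens = { 0, 1, 2, 3 };
    std::vector<float> logits;
    llama_eval_stub(tokens, logits);        // generation path, flag defaults to false
    llama_eval_stub(tokens, logits, true);  // embedding path, mirrors the call added in main()
    return 0;
}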
@@ -720,56 +732,52 @@ bool llama_eval(
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
     }
 
-    // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute (ctx0, &gf);
-
-    // capture input sentence embedding
-    {
-        std::vector<float> embedding_representation;
-        embedding_representation.resize(n_embd);
-        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 2)), sizeof(float) * n_embd);
-        fprintf(stdout, "\n[\n");
-        for (int j = 0; j < embedding_representation.size()-1 ; j++){
-            fprintf(stdout, "%f, ", embedding_representation[j]);
-        }
-        fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
-        fprintf(stdout, "\n]\n");
-    }
-
-    // lm_head
-    {
-        inpL = ggml_mul_mat(ctx0, model.output, inpL);
-    }
-
-    // logits -> probs
-    //inpL = ggml_soft_max(ctx0, inpL);
-
-    //if (n_past%100 == 0) {
-    //    ggml_graph_print (&gf);
-    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
-    //}
-
-    //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
-
-    ggml_free(ctx0);
-
-    return true;
+    if(!embeding_mode){
+        // lm_head
+        {
+            inpL = ggml_mul_mat(ctx0, model.output, inpL);
+        }
+        // logits -> probs
+        //inpL = ggml_soft_max(ctx0, inpL);
+
+        // run the computation
+        ggml_build_forward_expand(&gf, inpL);
+        ggml_graph_compute (ctx0, &gf);
+
+        //if (n_past%100 == 0) {
+        //    ggml_graph_print (&gf);
+        //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+        //}
+
+        //embd_w.resize(n_vocab*N);
+        //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        if (mem_per_token == 0) {
+            mem_per_token = ggml_used_mem(ctx0)/N;
+        }
+        //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+        ggml_free(ctx0);
+        return true;
+    } else {
+        // capture input sentence embedding
+        ggml_build_forward_expand(&gf, inpL);
+        ggml_graph_compute (ctx0, &gf);
+        printf("Compute went ok\n");
+        std::vector<float> embedding_representation;
+        embedding_representation.resize(n_embd);
+        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
+        printf("About to display\n");
+        display_embedding(embedding_representation);
+        printf("About to free\n");
+        ggml_free(ctx0);
+        return true;
+    }
 }
 
 static bool is_interacting = false;
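Both branches above copy one row out of the flat buffer returned by ggml_get_data(inpL): the tensor holds one row of output per evaluated token, so an offset of row_size * (N - 1) selects the last token. The generation path copies n_vocab logits, the embedding path copies n_embd hidden-state values, and the removed code had used N - 2 instead. The snippet below is a small sketch of that indexing with made-up sizes, not real ggml data.

// Sketch of the last-row copy used by both memcpy calls above.
// Sizes and contents are made up; in the diff the buffer comes from
// ggml_get_data(inpL) and row_size is n_vocab or n_embd.
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    const int N        = 4; // tokens evaluated in this call
    const int row_size = 3; // stands in for n_embd (or n_vocab)

    // Flat row-major buffer: row i holds the output for token i.
    std::vector<float> data(N * row_size);
    for (int i = 0; i < N * row_size; i++) data[i] = (float) i;

    // Copy only the last token's row, i.e. data + row_size * (N - 1).
    std::vector<float> last_row(row_size);
    std::memcpy(last_row.data(), data.data() + row_size * (N - 1), sizeof(float) * row_size);

    for (int j = 0; j < row_size; j++) printf("%f ", last_row[j]); // 9.000000 10.000000 11.000000
    printf("\n");
    return 0;
}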
@@ -906,13 +914,12 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, false);
 
     int last_n_size = params.repeat_last_n;
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -936,12 +943,27 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
+    if (params.embedding){
+        printf("got right before second call.\n");
+        const int64_t t_start_us = ggml_time_us(); //HERE
+        if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
+            fprintf(stderr, "Failed to predict\n");
+            return 1;
+        }
+        //ggml_free(model.ctx);
+
+        if (params.use_color) {
+            printf(ANSI_COLOR_RESET);
+        }
+        return 0;
+    }
+
     while (remaining_tokens > 0) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, false)) {
                 fprintf(stderr, "Failed to predict\n");
                 return 1;
             }
utils.cpp (2 changes)

@@ -53,6 +53,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[++i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
         } else if (arg == "--interactive-start") {
             params.interactive = true;
             params.interactive_start = true;
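The new flag follows the existing pattern in gpt_params_parse: match the argument string and set a boolean on the params struct. Below is a deliberately reduced sketch of that pattern handling only --embedding; gpt_params_mini and parse_embedding_flag are invented names for the example, and the real parser handles many more options.

// Reduced sketch of the --embedding flag handling added to gpt_params_parse.
// gpt_params_mini and parse_embedding_flag are illustrative names, not part of the commit.
#include <cstdio>
#include <string>

struct gpt_params_mini {
    bool embedding = false; // get only sentence embedding
};

static bool parse_embedding_flag(int argc, char ** argv, gpt_params_mini & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "--embedding") {
            params.embedding = true;
        } else {
            fprintf(stderr, "unknown argument: %s\n", arg.c_str());
            return false;
        }
    }
    return true;
}

int main(int argc, char ** argv) {
    gpt_params_mini params;
    if (!parse_embedding_flag(argc, argv, params)) {
        return 1;
    }
    printf("embedding mode: %s\n", params.embedding ? "on" : "off");
    return 0;
}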
utils.h (2 changes)

@@ -31,7 +31,7 @@ struct gpt_params {
     std::string prompt;
 
     bool use_color = false; // use color to distinguish generations and inputs
-
+    bool embedding = false; // get only sentence embedding
     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
     std::string antiprompt = ""; // string upon seeing which more user input is prompted