refactor the interface and fixed the styles

ningshanwutuobang 2023-06-25 01:21:54 +08:00
parent 53dfbbf553
commit 9d866118c4
11 changed files with 335 additions and 442 deletions


@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembd_input.so embd_input_test
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembdinput.so embd-input-test
 ifdef LLAMA_BUILD_SERVER
 BUILD_TARGETS += server
@@ -302,12 +302,12 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
 server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
-libembd_input.so: examples/embd_input/embd_input.h examples/embd_input/embd_input_lib.cpp examples/embd_input/embd_input_test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
-	$(CXX) --shared $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
-embd_input_test: libembd_input.so examples/embd_input/embd_input_test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
-	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -Wl,-rpath=./ -lembd_input
+libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
+embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
 train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)


@@ -0,0 +1,218 @@
// Defines sigaction on msys:
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "embd-input.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
static llama_context ** g_ctx;
extern "C" {
struct MyModel* create_mymodel(int argc, char ** argv) {
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
return nullptr;
}
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
if (params.seed < 0) {
params.seed = time(NULL);
}
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
llama_init_backend();
llama_context * ctx;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return nullptr;
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
struct MyModel * ret = new MyModel();
ret->ctx = ctx;
ret->params = params;
ret->n_past = 0;
// printf("ctx: %d\n", ret->ctx);
return ret;
}
void free_mymodel(struct MyModel * mymodel) {
llama_context * ctx = mymodel->ctx;
llama_print_timings(ctx);
llama_free(ctx);
delete mymodel;
}
bool eval_float(void * model, float * input, int N){
MyModel * mymodel = (MyModel*)model;
llama_context * ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_emb = llama_n_embd(ctx);
int n_past = mymodel->n_past;
int n_batch = N; // params.n_batch;
for (int i = 0; i < (int) N; i += n_batch) {
int n_eval = (int) N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_tokens(void * model, std::vector<llama_token> tokens) {
MyModel * mymodel = (MyModel* )model;
llama_context * ctx;
ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_past = mymodel->n_past;
for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_id(struct MyModel* mymodel, int id) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(mymodel, tokens);
}
bool eval_string(struct MyModel * mymodel,const char* str){
llama_context * ctx = mymodel->ctx;
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
eval_tokens(mymodel, embd_inp);
return true;
}
llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
// int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties
// float nl_logit = logits[llama_token_nl()];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl()] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
}
return id;
}
const char * sampling(struct MyModel * mymodel) {
llama_context * ctx = mymodel->ctx;
int id = sampling_id(mymodel);
std::string ret;
if (id == llama_token_eos()) ret = "</s>";
else ret = llama_token_to_str(ctx, id);
eval_id(mymodel, id);
return ret.c_str();
}
}
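
For reference (not part of the diff): a minimal sketch of a caller driving the extern "C" interface defined above, mirroring embd-input-test.cpp below. The program name, model path, token count, and embedding values are illustrative placeholders, not anything prescribed by the commit.

// Sketch only: assumes embd-input.h from this commit (declares MyModel,
// create_mymodel, eval_string, eval_float, sampling, free_mymodel) and a
// build against libembdinput.so. Model path and embedding data are placeholders.
#include "embd-input.h"
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    char * argv[] = {
        (char *)"embd-input-sketch",
        (char *)"--model", (char *)"models/ggml-model-q4_0.bin",  // placeholder path
        (char *)"-c",      (char *)"2048",
    };
    MyModel * mymodel = create_mymodel(5, argv);
    if (mymodel == nullptr) {
        return 1;
    }

    // Feed a text prefix, then a block of raw embeddings, then continue with text.
    eval_string(mymodel, "user: describe the injected embedding.");
    const int n_tokens = 4;  // number of embedding rows to inject (placeholder)
    std::vector<float> embd(n_tokens * llama_n_embd(mymodel->ctx), 0.1f);  // placeholder values
    eval_float(mymodel, embd.data(), n_tokens);
    eval_string(mymodel, "assistant:");

    // Sample until the end-of-sequence marker, as embd-input-test.cpp does.
    for (int i = 0; i < 64; i++) {
        const char * piece = sampling(mymodel);
        if (strcmp(piece, "</s>") == 0) break;
        printf("%s", piece);
        fflush(stdout);
    }
    printf("\n");
    free_mymodel(mymodel);
    return 0;
}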


@@ -1,4 +1,4 @@
-#include "embd_input.h"
+#include "embd-input.h"
 #include <stdlib.h>
 #include <random>
 #include <string.h>
@@ -7,7 +7,10 @@ int main(int argc, char** argv) {
     auto mymodel = create_mymodel(argc, argv);
     int N = 10;
+    int max_tgt_len = 500;
     int n_embd = llama_n_embd(mymodel->ctx);
+    // add random float embd to test evaluation
     float * data = new float[N*n_embd];
     std::default_random_engine e;
     std::uniform_real_distribution<float> u(0,1);
@@ -16,19 +19,15 @@ int main(int argc, char** argv) {
     }
     eval_string(mymodel, "user: what is the color of the flag of UN?");
-    // printf("eval float");
     eval_float(mymodel, data, N);
     eval_string(mymodel, "assistant:");
-    // printf("eval float end\n");
     eval_string(mymodel, mymodel->params.prompt.c_str());
     const char* tmp;
-    for (int i=0;i < 500; i++) {
-        // int id = sampling_id(mymodel);
+    for (int i=0; i<max_tgt_len; i++) {
         tmp = sampling(mymodel);
         if (strcmp(tmp, "</s>")==0) break;
-        printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
+        printf("%s", tmp);
         fflush(stdout);
-        // eval_id(mymodel, id);
     }
     printf("\n");
     free_mymodel(mymodel);


@@ -1,8 +1,9 @@
 import ctypes
 from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
 import numpy as np
+import os

-libc = cdll.LoadLibrary("./libembd_input.so")
+libc = cdll.LoadLibrary("./libembdinput.so")
 libc.sampling.restype=c_char_p
 libc.create_mymodel.restype=c_void_p
 libc.eval_string.argtypes=[c_void_p, c_char_p]
@@ -16,7 +17,9 @@ class MyModel:
         c_str = [c_char_p(i.encode()) for i in args]
         args_c = (c_char_p * argc)(*c_str)
         self.model = c_void_p(libc.create_mymodel(argc, args_c))
-        # print("self.model", self.model)
+        self.max_tgt_len = 512
+        self.print_string_eval = True

     def __del__(self):
         libc.free_mymodel(self.model)
@@ -25,6 +28,8 @@ class MyModel:
     def eval_string(self, x):
         libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
+        if self.print_string_eval:
+            print(x)

     def eval_token(self, x):
         libc.eval_id(self.model, x)
@@ -33,49 +38,34 @@ class MyModel:
         s = libc.sampling(self.model)
         return s

-    def generate(self, end="</s>"):
-        ret = b""
-        end = end.encode()
-        for _ in range(500):
-            tmp = self.sampling() # .decode()
-            if (ret+tmp).endswith(end):
-                break
-            ret += tmp
-        return ret.decode()
-
     def stream_generate(self, end="</s>"):
         ret = b""
         end = end.encode()
-        head = b""
-        for _ in range(500):
-            tmp = self.sampling() # .decode()
+        for _ in range(self.max_tgt_len):
+            tmp = self.sampling()
             ret += tmp
-            try:
-                text = (head + tmp).decode()
-                print(text, end="")
-                head = b""
-            except:
-                head += text
+            yield tmp
             if ret.endswith(end):
                 break
-        print("")
-        return ret.decode()
+
+    def generate_with_print(self, end="</s>"):
+        ret = b""
+        for i in self.stream_generate(end=end):
+            ret += i
+            print(i.decode(errors="replace"), end="", flush=True)
+        print("")
+        return ret.decode(errors="replace")
+
+    def generate(self, end="</s>"):
+        text = b"".join(self.stream_generate(end=end))
+        return text.decode(errors="replace")

 if __name__ == "__main__":
     model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
-    # print(model)
     model.eval_string("""user: what is the color of the flag of UN?""")
-    # model.eval_token(100)
     x = np.random.random((5120,10))# , dtype=np.float32)
     model.eval_float(x)
     model.eval_string("""assistant:""")
-    # print(x[0,0], x[0,1],x[1,0])
-    # model.eval_float(x)
-    # print(libc)
-    for i in range(500):
-        tmp = model.sampling().decode()
-        if tmp == "":
-            break
-        print(tmp, end="", flush=True)
+    for i in model.generate():
+        print(i.decode(errors="replace"), end="", flush=True)


@@ -31,7 +31,7 @@ class Llava:
         self.model.eval_string("user: ")
         self.model.eval_string(question)
         self.model.eval_string("\nassistant: ")
-        return self.model.generate()
+        return self.model.generate_with_print()

     def chat_with_image(self, image, question):
         with torch.no_grad():
@@ -49,7 +49,7 @@ class Llava:
         self.model.eval_token(32003-1) # im_end
         self.model.eval_string(question)
         self.model.eval_string("\nassistant: ")
-        return self.model.generate()
+        return self.model.generate_with_print()

 if __name__=="__main__":
@@ -63,8 +63,8 @@ if __name__=="__main__":
     respose = a.chat_with_image(
         Image.open("./media/llama1-logo.png").convert('RGB'),
         "what is the text in the picture?")
-    print(respose)
-    print(a.chat("what is the color of it?"))
+    respose
+    a.chat("what is the color of it?")


@@ -7,11 +7,13 @@ from torch import nn
 import torch

 # use PandaGPT path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "PandaGPT","code","model"))
+panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
+imagebind_ckpt_path = "./models/panda_gpt/"
+sys.path.insert(0, os.path.join(panda_gpt_path,"code","model"))
 from ImageBind.models import imagebind_model
 from ImageBind import data
-imagebind_ckpt_path = "./models/panda_gpt/"

 ModalityType = imagebind_model.ModalityType
 max_tgt_len = 400
@@ -31,25 +33,25 @@ class PandaGPT:
             "weight": state["llama_proj.weight"],
             "bias": state["llama_proj.bias"]})

+    def eval_inputs(self, inputs):
+        self.model.eval_string("<Img>")
+        embds = self.extract_multimoal_feature(inputs)
+        for i in embds:
+            self.model.eval_float(i.T)
+        self.model.eval_string("</Img> ")
+
     def chat(self, question):
-        if self.generated_text == "":
-            self.model.eval_string("###")
-        self.model.eval_string(" Human: ")
-        self.model.eval_string(question)
-        self.model.eval_string("\n### Assistant:")
-        ret = self.model.stream_generate(end="###")
-        self.generated_text += ret
-        return ret
+        return self.chat_with_image(None, question)

     def chat_with_image(self, inputs, question):
         if self.generated_text == "":
             self.model.eval_string("###")
-        self.model.eval_string(" Human: <Img>")
-        embds = self.extract_multimoal_feature(inputs)
-        for i in embds:
-            self.model.eval_float(i.T)
-        self.model.eval_string("</Img> " + question + "\n### Assistant:")
-        ret = self.model.stream_generate(end="###")
+        self.model.eval_string(" Human: ")
+        if inputs:
+            self.eval_inputs(inputs)
+        self.model.eval_string(question)
+        self.model.eval_string("\n### Assistant:")
+        ret = self.model.generate_with_print(end="###")
         self.generated_text += ret
         return ret
@@ -88,13 +90,9 @@ class PandaGPT:

 if __name__=="__main__":
-    # model form liuhaotian/LLaVA-13b-delta-v1-1
     a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
-    # Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
-    # Also here can use pytorch_model-00003-of-00003.bin directly.
     a.load_projection("./models/panda_gpt/adapter_model.bin")
     a.chat_with_image(
         {"image_paths": ["./media/llama1-logo.png"]},
         "what is the text in the picture? 'llama' or 'lambda'?")
     a.chat("what is the color of it?")


@@ -1,283 +0,0 @@
// Defines sigaction on msys:
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "embd_input.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif
static console_state con_st;
static llama_context ** g_ctx;
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
} else {
console_cleanup(con_st);
printf("\n");
llama_print_timings(*g_ctx);
_exit(130);
}
}
}
#endif
extern "C" {
struct MyModel* create_mymodel(int argc, char ** argv) {
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
return nullptr;
}
if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
"expect poor results\n", __func__, params.n_ctx);
}
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
if (params.seed < 0) {
params.seed = time(NULL);
}
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}
llama_init_backend();
llama_context * ctx;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return nullptr;
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
struct MyModel* ret= new MyModel();
ret->ctx = ctx;
ret->params = params;
ret->n_past = 0;
// printf("ctx: %d\n", ret->ctx);
return ret;
}
void free_mymodel(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
llama_print_timings(ctx);
llama_free(ctx);
delete mymodel;
}
bool eval_float(void* model, float* input, int N){
MyModel* mymodel = (MyModel* )model;
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_emb = llama_n_embd(ctx);
int n_past = mymodel->n_past;
// printf("%f,%f\n", *input, *(input+1));
int n_batch = N; // params.n_batch;
for (int i = 0; i < (int) N; i += n_batch) {
int n_eval = (int) N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_eval_float(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_tokens(void* model, std::vector<llama_token> tokens) {
MyModel* mymodel = (MyModel* )model;
// printf("model: %d\n", mymodel);
llama_context* ctx;// = mymodel->ctx;
// printf("ctx2: %d\n", ctx);
// printf("ctx2: %d\n", mymodel->ctx);
ctx = mymodel->ctx;
// printf("ctx2: %d\n", ctx);
gpt_params params = mymodel->params;
// printf("\n%d\n", params);
int n_past = mymodel->n_past;
for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
// printf("%d, %d, %d\n", i, n_eval, n_past);
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_id(struct MyModel* mymodel, int id) {
// printf("%d\n", id);
std::vector<llama_token> tokens;
tokens.push_back(id);
// printf("%d\n", tokens.size());
// printf("%d\n", tokens[0]);
return eval_tokens(mymodel, tokens);
}
bool eval_string(struct MyModel* mymodel,const char* str){
// std::cout << "eval " << std::endl;
// printf("%s", str);
llama_context* ctx = mymodel->ctx;
std::string str2 = str;
// printf("%s", str2.c_str());
std::cout << str2 << std::endl;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
eval_tokens(mymodel, embd_inp);
return true;
}
llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
// int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// Apply penalties
// float nl_logit = logits[llama_token_nl()];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl()] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
}
return id;
}
const char* sampling(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
int id = sampling_id(mymodel);
std::string ret;
if (id == llama_token_eos()) ret = "</s>";
else ret = llama_token_to_str(ctx, id);
eval_id(mymodel, id);
return ret.c_str();
}
}

llama.cpp

@@ -1342,15 +1342,33 @@ static bool llama_model_load(
     }
 }

-static bool llama_eval_internal_tensor(
+// evaluate the transformer
+//
+// - lctx: llama context
+// - tokens: new batch of tokens to process
+// - n_tokens number of tokens
+// - embd embeddings input
+// - n_past: the context size so far
+// - n_threads: number of threads to use
+//
+static bool llama_eval_internal(
         llama_context & lctx,
-        ggml_context* ctx0,
-        ggml_tensor* inpL,
+        const llama_token * tokens,
         const int n_tokens,
+        const float * embd,
         const int n_past,
         const int n_threads,
-        const char * cgraph_fname,
-        const int64_t t_start_us) {
+        const char * cgraph_fname) {
+
+    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
+
+    // enforce that the first token is BOS
+    if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) {
+        fprintf(stderr, "%s: first token must be BOS\n", __func__);
+        return false;
+    }
+
+    const int64_t t_start_us = ggml_time_us();

     const int N = n_tokens;
@@ -1359,7 +1377,6 @@ static bool llama_eval_internal_tensor(
     const auto & kv_self = model.kv_self;

     LLAMA_ASSERT(!!kv_self.ctx);

     const int n_embd = hparams.n_embd;
@@ -1371,6 +1388,15 @@ static bool llama_eval_internal_tensor(
     const int n_gpu_layers = model.n_gpu_layers;

     auto & mem_per_token = lctx.mem_per_token;
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.addr,
+        /*.no_alloc =*/ false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);

     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
@@ -1378,6 +1404,17 @@ static bool llama_eval_internal_tensor(
     gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

     struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    if (tokens) {
+        struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+        ggml_set_name(embd, "embd");
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+    } else {
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+    }

     const int i_gpu_start = n_layer - n_gpu_layers;
     (void) i_gpu_start;
@@ -1746,53 +1783,6 @@ static bool llama_eval_internal_tensor(
     return true;
 }

-// evaluate the transformer
-//
-// - lctx: llama context
-// - tokens: new batch of tokens to process
-// - n_past: the context size so far
-// - n_threads: number of threads to use
-//
-static bool llama_eval_internal(
-        llama_context & lctx,
-        const llama_token * tokens,
-        const int n_tokens,
-        const int n_past,
-        const int n_threads,
-        const char * cgraph_fname) {
-
-    // enforce that the first token is BOS
-    if (n_past == 0 && tokens[0] != llama_token_bos()) {
-        fprintf(stderr, "%s: first token must be BOS\n", __func__);
-        return false;
-    }
-
-    const auto & model = lctx.model;
-
-    const int64_t t_start_us = ggml_time_us();
-
-    const int N = n_tokens;
-
-    auto & buf_compute = lctx.buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.addr,
-        /*.no_alloc =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    ggml_set_name(embd, "embd");
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
-
-    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-
-    return llama_eval_internal_tensor(lctx, ctx0, inpL, N, n_past, n_threads, cgraph_fname, t_start_us);
-}
-
 //
 // tokenizer
 //
@@ -3357,7 +3347,7 @@ int llama_eval(
         int n_tokens,
         int n_past,
         int n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, nullptr, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3373,32 +3363,13 @@
 }

-int llama_eval_float(
+int llama_eval_embd(
         struct llama_context * ctx,
-        const float * input,
+        const float * embd,
         int n_tokens,
         int n_past,
         int n_threads) {
-    const auto & model = ctx->model;
-
-    const int64_t t_start_us = ggml_time_us();
-
-    const int N = n_tokens;
-
-    auto & buf_compute = ctx->buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.addr,
-        /*.no_alloc =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    struct ggml_tensor *inpL =
-        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, N);
-
-    memcpy(inpL->data, input, N * model.hparams.n_embd * ggml_element_size(inpL));
-
-    if (!llama_eval_internal_tensor(*ctx, ctx0, inpL, N, n_past, n_threads, nullptr, t_start_us)) {
+    if (!llama_eval_internal(*ctx, nullptr, n_tokens, embd, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3419,7 +3390,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
     const std::vector<llama_token> tmp(n_batch, llama_token_bos());

-    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
+    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), nullptr, n_ctx, 1, fname)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }


@@ -200,9 +200,9 @@ extern "C" {
             int n_threads);

     // Same as llama_eval, but use float matrix input directly.
-    LLAMA_API int llama_eval_float(
+    LLAMA_API int llama_eval_embd(
             struct llama_context * ctx,
-            const float * embds,
+            const float * embd,
             int n_tokens,
             int n_past,
             int n_threads);
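
For reference (not part of the diff): a minimal sketch of calling the renamed llama_eval_embd entry point declared above. It assumes a llama_context that was already created elsewhere (for example via llama_init_from_gpt_params, as done in embd-input-lib.cpp); the zero-filled embedding matrix and the helper name are placeholders.

// Sketch only: `ctx` is assumed to be a valid llama_context * created elsewhere;
// the embedding values are placeholders. The matrix layout matches the diff:
// n_tokens rows of llama_n_embd(ctx) floats, consumed like a batch of tokens.
#include "llama.h"
#include <cstdio>
#include <vector>

bool eval_placeholder_embeddings(llama_context * ctx, int n_tokens, int n_past, int n_threads) {
    const int n_embd = llama_n_embd(ctx);               // embedding width expected by the model
    std::vector<float> embd(n_tokens * n_embd, 0.0f);   // placeholder embedding batch

    // Appends n_tokens embedding rows to the context starting at position n_past,
    // exactly where llama_eval would append token ids.
    if (llama_eval_embd(ctx, embd.data(), n_tokens, n_past, n_threads) != 0) {
        fprintf(stderr, "llama_eval_embd failed\n");
        return false;
    }
    return true;
}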