Merge branch 'ggerganov:master' into master
This commit is contained in:
commit
1d22550f87
21 changed files with 939 additions and 253 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,5 +1,6 @@
|
|||
*.o
|
||||
*.a
|
||||
*.so
|
||||
.DS_Store
|
||||
.build/
|
||||
.cache/
|
||||
|
@ -39,8 +40,8 @@ models/*
|
|||
/vdot
|
||||
/server
|
||||
/Pipfile
|
||||
/embd-input-test
|
||||
/libllama.so
|
||||
|
||||
build-info.h
|
||||
arm_neon.h
|
||||
compile_commands.json
|
||||
|
|
11
Makefile
11
Makefile
|
@ -1,5 +1,5 @@
|
|||
# Define the default target now so that it is always the first target
|
||||
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple
|
||||
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembdinput.so embd-input-test
|
||||
|
||||
ifdef LLAMA_BUILD_SERVER
|
||||
BUILD_TARGETS += server
|
||||
|
@ -272,7 +272,7 @@ libllama.so: llama.o ggml.o $(OBJS)
|
|||
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch build-info.h
|
||||
rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch embd-input-test build-info.h
|
||||
|
||||
#
|
||||
# Examples
|
||||
|
@ -305,6 +305,13 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
|
|||
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
|
||||
|
||||
libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
|
||||
|
||||
|
||||
embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
|
||||
|
||||
train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||
|
||||
|
|
|
@ -113,6 +113,10 @@ with open(output_path, "wb") as fout:
|
|||
|
||||
write_file_header(fout, params)
|
||||
for k, v in model.items():
|
||||
if k.endswith(".default.weight"):
|
||||
k = k.replace(".default.weight", ".weight")
|
||||
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
||||
continue
|
||||
if k.endswith("lora_A.weight"):
|
||||
if v.dtype != torch.float16 and v.dtype != torch.float32:
|
||||
v = v.float()
|
||||
|
@ -120,7 +124,7 @@ with open(output_path, "wb") as fout:
|
|||
else:
|
||||
v = v.float()
|
||||
|
||||
t = v.numpy()
|
||||
t = v.detach().numpy()
|
||||
tname = translate_tensor_name(k)
|
||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||
|
|
|
@ -39,6 +39,7 @@ else()
|
|||
add_subdirectory(baby-llama)
|
||||
add_subdirectory(train-text-from-scratch)
|
||||
add_subdirectory(simple)
|
||||
add_subdirectory(embd-input)
|
||||
if (LLAMA_METAL)
|
||||
add_subdirectory(metal)
|
||||
endif()
|
||||
|
|
|
@ -566,8 +566,8 @@ struct ggml_tensor * forward(
|
|||
// wk shape [n_embd, n_embd, 1, 1]
|
||||
// Qcur shape [n_embd/n_head, n_head, N, 1]
|
||||
// Kcur shape [n_embd/n_head, n_head, N, 1]
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
|
@ -823,8 +823,8 @@ struct ggml_tensor * forward_batch(
|
|||
// wk shape [n_embd, n_embd, 1, 1]
|
||||
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
|
||||
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
|
||||
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
|
||||
|
||||
|
@ -1116,7 +1116,7 @@ struct ggml_tensor * forward_lora(
|
|||
model->layers[il].wqb,
|
||||
cur)),
|
||||
n_embd/n_head, n_head, N),
|
||||
n_past, n_rot, 0);
|
||||
n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_mul_mat(ctx0,
|
||||
|
@ -1125,7 +1125,7 @@ struct ggml_tensor * forward_lora(
|
|||
model->layers[il].wkb,
|
||||
cur)),
|
||||
n_embd/n_head, n_head, N),
|
||||
n_past, n_rot, 0);
|
||||
n_past, n_rot, 0, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
|
|
|
@ -416,13 +416,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
exit(1);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) {
|
||||
fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__);
|
||||
exit(1);
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
if (escape_prompt) {
|
||||
process_escapes(params.prompt);
|
||||
}
|
||||
|
|
4
examples/embd-input/.gitignore
vendored
Normal file
4
examples/embd-input/.gitignore
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
PandaGPT
|
||||
MiniGPT-4
|
||||
*.pth
|
||||
|
15
examples/embd-input/CMakeLists.txt
Normal file
15
examples/embd-input/CMakeLists.txt
Normal file
|
@ -0,0 +1,15 @@
|
|||
set(TARGET embdinput)
|
||||
add_library(${TARGET} embd-input-lib.cpp embd-input.h)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
||||
|
||||
set(TARGET embd-input-test)
|
||||
add_executable(${TARGET} embd-input-test.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
63
examples/embd-input/README.md
Normal file
63
examples/embd-input/README.md
Normal file
|
@ -0,0 +1,63 @@
|
|||
### Examples for input embedding directly
|
||||
|
||||
## Requirement
|
||||
build `libembdinput.so`
|
||||
run the following comman in main dir (../../).
|
||||
```
|
||||
make
|
||||
```
|
||||
|
||||
## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py)
|
||||
|
||||
1. Obtian LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
|
||||
2. Convert it to ggml format.
|
||||
3. `llava_projection.pth` is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin).
|
||||
|
||||
```
|
||||
import torch
|
||||
|
||||
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
|
||||
pth_path = "./examples/embd_input/llava_projection.pth"
|
||||
|
||||
dic = torch.load(bin_path)
|
||||
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
|
||||
torch.save({k: dic[k] for k in used_key}, pth_path)
|
||||
```
|
||||
4. Check the path of LLaVA model and `llava_projection.pth` in `llava.py`.
|
||||
|
||||
|
||||
## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py)
|
||||
|
||||
1. Obtian PandaGPT lora model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
|
||||
The `adapter_config.json` is
|
||||
```
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"fan_in_fan_out": false,
|
||||
"bias": null,
|
||||
"modules_to_save": null,
|
||||
"r": 32,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
}
|
||||
```
|
||||
2. Papare the `vicuna` v0 model.
|
||||
3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
|
||||
4. Clone the PandaGPT source.
|
||||
```
|
||||
git clone https://github.com/yxuansu/PandaGPT
|
||||
```
|
||||
5. Install the requirement of PandaGPT.
|
||||
6. Check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py.
|
||||
|
||||
## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py)
|
||||
|
||||
1. Obtain MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`.
|
||||
2. Clone the MiniGPT-4 source.
|
||||
```
|
||||
git clone https://github.com/Vision-CAIR/MiniGPT-4/
|
||||
```
|
||||
3. Install the requirement of PandaGPT.
|
||||
4. Papare the `vicuna` v0 model.
|
||||
5. Check the path of MiniGPT-4 source, MiniGPT-4 model and vicuna model in `minigpt4.py`.
|
220
examples/embd-input/embd-input-lib.cpp
Normal file
220
examples/embd-input/embd-input-lib.cpp
Normal file
|
@ -0,0 +1,220 @@
|
|||
// Defines sigaction on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#endif
|
||||
|
||||
#include "embd-input.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static llama_context ** g_ctx;
|
||||
|
||||
extern "C" {
|
||||
|
||||
struct MyModel* create_mymodel(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
if (gpt_params_parse(argc, argv, params) == false) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
|
||||
|
||||
if (params.seed < 0) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
|
||||
|
||||
llama_init_backend(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
||||
g_ctx = &ctx;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: unable to load model\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
struct MyModel * ret = new MyModel();
|
||||
ret->ctx = ctx;
|
||||
ret->params = params;
|
||||
ret->n_past = 0;
|
||||
// printf("ctx: %d\n", ret->ctx);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void free_mymodel(struct MyModel * mymodel) {
|
||||
llama_context * ctx = mymodel->ctx;
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
delete mymodel;
|
||||
}
|
||||
|
||||
|
||||
bool eval_float(void * model, float * input, int N){
|
||||
MyModel * mymodel = (MyModel*)model;
|
||||
llama_context * ctx = mymodel->ctx;
|
||||
gpt_params params = mymodel->params;
|
||||
int n_emb = llama_n_embd(ctx);
|
||||
int n_past = mymodel->n_past;
|
||||
int n_batch = N; // params.n_batch;
|
||||
|
||||
for (int i = 0; i < (int) N; i += n_batch) {
|
||||
int n_eval = (int) N - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
mymodel->n_past = n_past;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool eval_tokens(void * model, std::vector<llama_token> tokens) {
|
||||
MyModel * mymodel = (MyModel* )model;
|
||||
llama_context * ctx;
|
||||
ctx = mymodel->ctx;
|
||||
gpt_params params = mymodel->params;
|
||||
int n_past = mymodel->n_past;
|
||||
for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
mymodel->n_past = n_past;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool eval_id(struct MyModel* mymodel, int id) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(mymodel, tokens);
|
||||
}
|
||||
|
||||
bool eval_string(struct MyModel * mymodel,const char* str){
|
||||
llama_context * ctx = mymodel->ctx;
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
|
||||
eval_tokens(mymodel, embd_inp);
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_token sampling_id(struct MyModel* mymodel) {
|
||||
llama_context* ctx = mymodel->ctx;
|
||||
gpt_params params = mymodel->params;
|
||||
// int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
// out of user input, sample next token
|
||||
const float temp = params.temp;
|
||||
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
||||
// const float repeat_penalty = params.repeat_penalty;
|
||||
// const float alpha_presence = params.presence_penalty;
|
||||
// const float alpha_frequency = params.frequency_penalty;
|
||||
const int mirostat = params.mirostat;
|
||||
const float mirostat_tau = params.mirostat_tau;
|
||||
const float mirostat_eta = params.mirostat_eta;
|
||||
// const bool penalize_nl = params.penalize_nl;
|
||||
|
||||
llama_token id = 0;
|
||||
{
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
// Apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
}
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// TODO: Apply penalties
|
||||
// float nl_logit = logits[llama_token_nl()];
|
||||
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
||||
// llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
// last_n_repeat, repeat_penalty);
|
||||
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
// last_n_repeat, alpha_frequency, alpha_presence);
|
||||
// if (!penalize_nl) {
|
||||
// logits[llama_token_nl()] = nl_logit;
|
||||
// }
|
||||
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
} else {
|
||||
if (mirostat == 1) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
const char * sampling(struct MyModel * mymodel) {
|
||||
llama_context * ctx = mymodel->ctx;
|
||||
int id = sampling_id(mymodel);
|
||||
std::string ret;
|
||||
if (id == llama_token_eos()) ret = "</s>";
|
||||
else ret = llama_token_to_str(ctx, id);
|
||||
eval_id(mymodel, id);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
}
|
35
examples/embd-input/embd-input-test.cpp
Normal file
35
examples/embd-input/embd-input-test.cpp
Normal file
|
@ -0,0 +1,35 @@
|
|||
#include "embd-input.h"
|
||||
#include <stdlib.h>
|
||||
#include <random>
|
||||
#include <string.h>
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
auto mymodel = create_mymodel(argc, argv);
|
||||
int N = 10;
|
||||
int max_tgt_len = 500;
|
||||
int n_embd = llama_n_embd(mymodel->ctx);
|
||||
|
||||
// add random float embd to test evaluation
|
||||
float * data = new float[N*n_embd];
|
||||
std::default_random_engine e;
|
||||
std::uniform_real_distribution<float> u(0,1);
|
||||
for (int i=0;i<N*n_embd;i++) {
|
||||
data[i] = u(e);
|
||||
}
|
||||
|
||||
eval_string(mymodel, "user: what is the color of the flag of UN?");
|
||||
eval_float(mymodel, data, N);
|
||||
eval_string(mymodel, "assistant:");
|
||||
eval_string(mymodel, mymodel->params.prompt.c_str());
|
||||
const char* tmp;
|
||||
for (int i=0; i<max_tgt_len; i++) {
|
||||
tmp = sampling(mymodel);
|
||||
if (strcmp(tmp, "</s>")==0) break;
|
||||
printf("%s", tmp);
|
||||
fflush(stdout);
|
||||
}
|
||||
printf("\n");
|
||||
free_mymodel(mymodel);
|
||||
return 0;
|
||||
}
|
30
examples/embd-input/embd-input.h
Normal file
30
examples/embd-input/embd-input.h
Normal file
|
@ -0,0 +1,30 @@
|
|||
#ifndef _EMBD_INPUT_H_
|
||||
#define _EMBD_INPUT_H_ 1
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "build-info.h"
|
||||
|
||||
|
||||
extern "C" {
|
||||
|
||||
typedef struct MyModel {
|
||||
llama_context* ctx;
|
||||
gpt_params params;
|
||||
int n_past = 0;
|
||||
} MyModel;
|
||||
|
||||
|
||||
struct MyModel* create_mymodel(int argc, char ** argv);
|
||||
|
||||
bool eval_float(void* model, float* input, int N);
|
||||
bool eval_tokens(void* model, std::vector<llama_token> tokens);
|
||||
bool eval_id(struct MyModel* mymodel, int id);
|
||||
bool eval_string(struct MyModel* mymodel, const char* str);
|
||||
const char* sampling(struct MyModel* mymodel);
|
||||
llama_token sampling_id(struct MyModel* mymodel);
|
||||
void free_mymodel(struct MyModel* mymodel);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
71
examples/embd-input/embd_input.py
Normal file
71
examples/embd-input/embd_input.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import ctypes
|
||||
from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
libc = cdll.LoadLibrary("./libembdinput.so")
|
||||
libc.sampling.restype=c_char_p
|
||||
libc.create_mymodel.restype=c_void_p
|
||||
libc.eval_string.argtypes=[c_void_p, c_char_p]
|
||||
libc.sampling.argtypes=[c_void_p]
|
||||
libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int]
|
||||
|
||||
|
||||
class MyModel:
|
||||
def __init__(self, args):
|
||||
argc = len(args)
|
||||
c_str = [c_char_p(i.encode()) for i in args]
|
||||
args_c = (c_char_p * argc)(*c_str)
|
||||
self.model = c_void_p(libc.create_mymodel(argc, args_c))
|
||||
self.max_tgt_len = 512
|
||||
self.print_string_eval = True
|
||||
|
||||
def __del__(self):
|
||||
libc.free_mymodel(self.model)
|
||||
|
||||
def eval_float(self, x):
|
||||
libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
|
||||
|
||||
def eval_string(self, x):
|
||||
libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
|
||||
if self.print_string_eval:
|
||||
print(x)
|
||||
|
||||
def eval_token(self, x):
|
||||
libc.eval_id(self.model, x)
|
||||
|
||||
def sampling(self):
|
||||
s = libc.sampling(self.model)
|
||||
return s
|
||||
|
||||
def stream_generate(self, end="</s>"):
|
||||
ret = b""
|
||||
end = end.encode()
|
||||
for _ in range(self.max_tgt_len):
|
||||
tmp = self.sampling()
|
||||
ret += tmp
|
||||
yield tmp
|
||||
if ret.endswith(end):
|
||||
break
|
||||
|
||||
def generate_with_print(self, end="</s>"):
|
||||
ret = b""
|
||||
for i in self.stream_generate(end=end):
|
||||
ret += i
|
||||
print(i.decode(errors="replace"), end="", flush=True)
|
||||
print("")
|
||||
return ret.decode(errors="replace")
|
||||
|
||||
|
||||
def generate(self, end="</s>"):
|
||||
text = b"".join(self.stream_generate(end=end))
|
||||
return text.decode(errors="replace")
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
|
||||
model.eval_string("""user: what is the color of the flag of UN?""")
|
||||
x = np.random.random((5120,10))# , dtype=np.float32)
|
||||
model.eval_float(x)
|
||||
model.eval_string("""assistant:""")
|
||||
for i in model.generate():
|
||||
print(i.decode(errors="replace"), end="", flush=True)
|
70
examples/embd-input/llava.py
Normal file
70
examples/embd-input/llava.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from embd_input import MyModel
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch
|
||||
from transformers import CLIPVisionModel, CLIPImageProcessor
|
||||
from PIL import Image
|
||||
|
||||
# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
|
||||
vision_tower = "openai/clip-vit-large-patch14"
|
||||
select_hidden_state_layer = -2
|
||||
# (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
image_token_len = (224//14)**2
|
||||
|
||||
class Llava:
|
||||
def __init__(self, args):
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
||||
self.mm_projector = nn.Linear(1024, 5120)
|
||||
self.model = MyModel(["main", *args])
|
||||
|
||||
def load_projection(self, path):
|
||||
state = torch.load(path)
|
||||
self.mm_projector.load_state_dict({
|
||||
"weight": state["model.mm_projector.weight"],
|
||||
"bias": state["model.mm_projector.bias"]})
|
||||
|
||||
def chat(self, question):
|
||||
self.model.eval_string("user: ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
return self.model.generate_with_print()
|
||||
|
||||
def chat_with_image(self, image, question):
|
||||
with torch.no_grad():
|
||||
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
|
||||
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
||||
image_feature = select_hidden_state[:, 1:]
|
||||
embd_image = self.mm_projector(image_feature)
|
||||
embd_image = embd_image.cpu().numpy()[0]
|
||||
self.model.eval_string("user: ")
|
||||
self.model.eval_token(32003-2) # im_start
|
||||
self.model.eval_float(embd_image.T)
|
||||
for i in range(image_token_len-embd_image.shape[0]):
|
||||
self.model.eval_token(32003-3) # im_patch
|
||||
self.model.eval_token(32003-1) # im_end
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\nassistant: ")
|
||||
return self.model.generate_with_print()
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
||||
a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
|
||||
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
|
||||
# Also here can use pytorch_model-00003-of-00003.bin directly.
|
||||
a.load_projection(os.path.join(
|
||||
os.path.dirname(__file__) ,
|
||||
"llava_projetion.pth"))
|
||||
respose = a.chat_with_image(
|
||||
Image.open("./media/llama1-logo.png").convert('RGB'),
|
||||
"what is the text in the picture?")
|
||||
respose
|
||||
a.chat("what is the color of it?")
|
||||
|
||||
|
||||
|
128
examples/embd-input/minigpt4.py
Normal file
128
examples/embd-input/minigpt4.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from embd_input import MyModel
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
|
||||
sys.path.insert(0, minigpt4_path)
|
||||
from minigpt4.models.blip2 import Blip2Base
|
||||
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
|
||||
|
||||
|
||||
class MiniGPT4(Blip2Base):
|
||||
"""
|
||||
MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
|
||||
"""
|
||||
def __init__(self,
|
||||
args,
|
||||
vit_model="eva_clip_g",
|
||||
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision="fp32",
|
||||
freeze_vit=True,
|
||||
freeze_qformer=True,
|
||||
num_query_token=32,
|
||||
llama_model="",
|
||||
prompt_path="",
|
||||
prompt_template="",
|
||||
max_txt_len=32,
|
||||
end_sym='\n',
|
||||
low_resource=False, # use 8 bit and put vit in cpu
|
||||
device_8bit=0
|
||||
):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.low_resource = low_resource
|
||||
self.preprocessor = Blip2ImageEvalProcessor(img_size)
|
||||
|
||||
print('Loading VIT')
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
)
|
||||
print('Loading VIT Done')
|
||||
print('Loading Q-Former')
|
||||
self.Qformer, self.query_tokens = self.init_Qformer(
|
||||
num_query_token, self.visual_encoder.num_features
|
||||
)
|
||||
self.Qformer.cls = None
|
||||
self.Qformer.bert.embeddings.word_embeddings = None
|
||||
self.Qformer.bert.embeddings.position_embeddings = None
|
||||
for layer in self.Qformer.bert.encoder.layer:
|
||||
layer.output = None
|
||||
layer.intermediate = None
|
||||
self.load_from_pretrained(url_or_filename=q_former_model)
|
||||
print('Loading Q-Former Done')
|
||||
self.llama_proj = nn.Linear(
|
||||
self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size
|
||||
)
|
||||
self.max_txt_len = max_txt_len
|
||||
self.end_sym = end_sym
|
||||
self.model = MyModel(["main", *args])
|
||||
# system promt
|
||||
self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
|
||||
"You will be able to see the image once I provide it to you. Please answer my questions."
|
||||
"###")
|
||||
|
||||
def encode_img(self, image):
|
||||
image = self.preprocessor(image)
|
||||
image = image.unsqueeze(0)
|
||||
device = image.device
|
||||
if self.low_resource:
|
||||
self.vit_to_cpu()
|
||||
image = image.to("cpu")
|
||||
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_output = self.Qformer.bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
inputs_llama = self.llama_proj(query_output.last_hidden_state)
|
||||
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||
return inputs_llama
|
||||
|
||||
def load_projection(self, path):
|
||||
state = torch.load(path)["model"]
|
||||
self.llama_proj.load_state_dict({
|
||||
"weight": state["llama_proj.weight"],
|
||||
"bias": state["llama_proj.bias"]})
|
||||
|
||||
def chat(self, question):
|
||||
self.model.eval_string("Human: ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\n### Assistant:")
|
||||
return self.model.generate_with_print(end="###")
|
||||
|
||||
def chat_with_image(self, image, question):
|
||||
with torch.no_grad():
|
||||
embd_image = self.encode_img(image)
|
||||
embd_image = embd_image.cpu().numpy()[0]
|
||||
self.model.eval_string("Human: <Img>")
|
||||
self.model.eval_float(embd_image.T)
|
||||
self.model.eval_string("</Img> ")
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\n### Assistant:")
|
||||
return self.model.generate_with_print(end="###")
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
|
||||
a.load_projection(os.path.join(
|
||||
os.path.dirname(__file__) ,
|
||||
"pretrained_minigpt4.pth"))
|
||||
respose = a.chat_with_image(
|
||||
Image.open("./media/llama1-logo.png").convert('RGB'),
|
||||
"what is the text in the picture?")
|
||||
a.chat("what is the color of it?")
|
98
examples/embd-input/panda_gpt.py
Normal file
98
examples/embd-input/panda_gpt.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from embd_input import MyModel
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
# use PandaGPT path
|
||||
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
|
||||
imagebind_ckpt_path = "./models/panda_gpt/"
|
||||
|
||||
sys.path.insert(0, os.path.join(panda_gpt_path,"code","model"))
|
||||
from ImageBind.models import imagebind_model
|
||||
from ImageBind import data
|
||||
|
||||
ModalityType = imagebind_model.ModalityType
|
||||
max_tgt_len = 400
|
||||
|
||||
class PandaGPT:
|
||||
def __init__(self, args):
|
||||
self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
|
||||
self.visual_encoder.eval()
|
||||
self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
|
||||
self.max_tgt_len = max_tgt_len
|
||||
self.model = MyModel(["main", *args])
|
||||
self.generated_text = ""
|
||||
self.device = "cpu"
|
||||
|
||||
def load_projection(self, path):
|
||||
state = torch.load(path, map_location="cpu")
|
||||
self.llama_proj.load_state_dict({
|
||||
"weight": state["llama_proj.weight"],
|
||||
"bias": state["llama_proj.bias"]})
|
||||
|
||||
def eval_inputs(self, inputs):
|
||||
self.model.eval_string("<Img>")
|
||||
embds = self.extract_multimoal_feature(inputs)
|
||||
for i in embds:
|
||||
self.model.eval_float(i.T)
|
||||
self.model.eval_string("</Img> ")
|
||||
|
||||
def chat(self, question):
|
||||
return self.chat_with_image(None, question)
|
||||
|
||||
def chat_with_image(self, inputs, question):
|
||||
if self.generated_text == "":
|
||||
self.model.eval_string("###")
|
||||
self.model.eval_string(" Human: ")
|
||||
if inputs:
|
||||
self.eval_inputs(inputs)
|
||||
self.model.eval_string(question)
|
||||
self.model.eval_string("\n### Assistant:")
|
||||
ret = self.model.generate_with_print(end="###")
|
||||
self.generated_text += ret
|
||||
return ret
|
||||
|
||||
def extract_multimoal_feature(self, inputs):
|
||||
features = []
|
||||
for key in ["image", "audio", "video", "thermal"]:
|
||||
if key + "_paths" in inputs:
|
||||
embeds = self.encode_data(key, inputs[key+"_paths"])
|
||||
features.append(embeds)
|
||||
return features
|
||||
|
||||
def encode_data(self, data_type, data_paths):
|
||||
|
||||
type_map = {
|
||||
"image": ModalityType.VISION,
|
||||
"audio": ModalityType.AUDIO,
|
||||
"video": ModalityType.VISION,
|
||||
"thermal": ModalityType.THERMAL,
|
||||
}
|
||||
load_map = {
|
||||
"image": data.load_and_transform_vision_data,
|
||||
"audio": data.load_and_transform_audio_data,
|
||||
"video": data.load_and_transform_video_data,
|
||||
"thermal": data.load_and_transform_thermal_data
|
||||
}
|
||||
|
||||
load_function = load_map[data_type]
|
||||
key = type_map[data_type]
|
||||
|
||||
inputs = {key: load_function(data_paths, self.device)}
|
||||
with torch.no_grad():
|
||||
embeddings = self.visual_encoder(inputs)
|
||||
embeds = embeddings[key]
|
||||
embeds = self.llama_proj(embeds).cpu().numpy()
|
||||
return embeds
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
|
||||
a.load_projection("./models/panda_gpt/adapter_model.bin")
|
||||
a.chat_with_image(
|
||||
{"image_paths": ["./media/llama1-logo.png"]},
|
||||
"what is the text in the picture? 'llama' or 'lambda'?")
|
||||
a.chat("what is the color of it?")
|
65
ggml-cuda.cu
65
ggml-cuda.cu
|
@ -223,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
|
|||
dst[i] = x[i] + y[i];
|
||||
}
|
||||
|
||||
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = __hadd(x[i], __float2half(y[i]));
|
||||
}
|
||||
|
||||
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
|
@ -1235,7 +1244,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
|
|||
}
|
||||
|
||||
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
|
||||
const half * x = (half *) vx;
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
||||
|
@ -1283,9 +1292,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
|
|||
|
||||
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
|
||||
const int row_stride_x, const int channel_stride_x) {
|
||||
|
||||
const half * x = (half *) vx;
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
||||
|
@ -1328,14 +1337,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|||
}
|
||||
|
||||
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
||||
const float * xi = (float *) cxi;
|
||||
const float * xi = (const float *) cxi;
|
||||
float * dsti = (float *) cdsti;
|
||||
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
||||
const float * xi = (float *) cxi;
|
||||
const float * xi = (const float *) cxi;
|
||||
half * dsti = (half *) cdsti;
|
||||
|
||||
*dsti = __float2half(*xi);
|
||||
|
@ -1459,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
|
|||
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||
}
|
||||
|
||||
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
||||
add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||
}
|
||||
|
||||
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
||||
const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
|
||||
mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
||||
|
@ -1684,7 +1698,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
|
|||
const dim3 block_nums(1, nrows_x, nchannels_x);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
|
||||
(vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f32_cuda(
|
||||
|
@ -1941,7 +1955,7 @@ inline void ggml_cuda_op_add(
|
|||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||
cudaStream_t & cudaStream_main){
|
||||
|
||||
GGML_ASSERT(src0_ddf_i != nullptr);
|
||||
GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
|
||||
GGML_ASSERT(src1_ddf_i != nullptr);
|
||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||
|
||||
|
@ -1949,7 +1963,13 @@ inline void ggml_cuda_op_add(
|
|||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
// compute
|
||||
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
(void) src1;
|
||||
|
@ -2547,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
}
|
||||
|
||||
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
|
||||
// ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
|
||||
// Due to flatten_rows == true this does in practice not make a difference however.
|
||||
// Better solution would be nice but right now that would require disproportionate changes.
|
||||
GGML_ASSERT(
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
@ -2801,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
|
|||
delete extra;
|
||||
}
|
||||
|
||||
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
|
||||
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
|
||||
if (scratch && g_scratch_size == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -2810,11 +2836,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
|
|||
if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
|
||||
const ggml_op src0_op = tensor->src0->op;
|
||||
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
|
||||
ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
|
||||
ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
|
||||
}
|
||||
}
|
||||
if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
|
||||
ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
|
||||
ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
|
||||
}
|
||||
|
||||
tensor->backend = GGML_BACKEND_GPU;
|
||||
|
@ -2822,11 +2848,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
|
|||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
|
||||
tensor->op == GGML_OP_VIEW;
|
||||
tensor->op == GGML_OP_VIEW ||
|
||||
force_inplace;
|
||||
const size_t size = ggml_nbytes(tensor);
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
|
||||
if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
|
||||
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
|
||||
size_t offset = 0;
|
||||
|
@ -2865,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
|
|||
}
|
||||
|
||||
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
|
||||
ggml_cuda_assign_buffers_impl(tensor, true);
|
||||
ggml_cuda_assign_buffers_impl(tensor, true, false);
|
||||
}
|
||||
|
||||
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
|
||||
ggml_cuda_assign_buffers_impl(tensor, false);
|
||||
ggml_cuda_assign_buffers_impl(tensor, false, false);
|
||||
}
|
||||
|
||||
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
|
||||
ggml_cuda_assign_buffers_impl(tensor, false, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_set_main_device(int main_device) {
|
||||
|
|
|
@ -29,6 +29,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
|
|||
void ggml_cuda_free_data(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_set_main_device(int main_device);
|
||||
void ggml_cuda_set_scratch_size(size_t scratch_size);
|
||||
void ggml_cuda_free_scratch(void);
|
||||
|
|
3
ggml.c
3
ggml.c
|
@ -16684,7 +16684,8 @@ typedef pthread_t ggml_thread_t;
|
|||
|
||||
#endif
|
||||
|
||||
#ifdef __linux__
|
||||
// Android's libc implementation "bionic" does not support setting affinity
|
||||
#if defined(__linux__) && !defined(__BIONIC__)
|
||||
void set_numa_thread_affinity(int thread_n, int n_threads) {
|
||||
if (!ggml_is_numa()) {
|
||||
return;
|
||||
|
|
341
llama.cpp
341
llama.cpp
|
@ -364,96 +364,14 @@ static size_t llama_calc_tensor_size(const std::vector<uint32_t> & ne, enum ggml
|
|||
return size / ggml_blck_size(type);
|
||||
}
|
||||
|
||||
struct llama_load_tensor_shard {
|
||||
std::vector<uint32_t> ne;
|
||||
size_t size;
|
||||
enum ggml_type type;
|
||||
size_t file_idx;
|
||||
size_t file_off;
|
||||
|
||||
void calc_size() {
|
||||
size = llama_calc_tensor_size(ne, type);
|
||||
}
|
||||
};
|
||||
|
||||
enum llama_split_type {
|
||||
SPLIT_NONE,
|
||||
SPLIT_BY_COLUMNS,
|
||||
SPLIT_BY_ROWS
|
||||
};
|
||||
|
||||
struct llama_load_tensor {
|
||||
std::vector<llama_load_tensor_shard> shards;
|
||||
|
||||
std::string name;
|
||||
enum ggml_type type = GGML_TYPE_F32;
|
||||
llama_split_type split_type = SPLIT_NONE;
|
||||
std::vector<uint32_t> ne;
|
||||
size_t file_off;
|
||||
size_t size;
|
||||
struct ggml_tensor * ggml_tensor = NULL;
|
||||
uint8_t * data;
|
||||
|
||||
llama_load_tensor(const std::string & name) : name(name) {}
|
||||
|
||||
void calc_all() {
|
||||
calc_type();
|
||||
calc_split_type();
|
||||
calc_ne();
|
||||
calc_size();
|
||||
}
|
||||
|
||||
void calc_type() {
|
||||
const auto & first_shard = shards.at(0);
|
||||
for (const auto & shard : shards) {
|
||||
if (shard.type != first_shard.type) {
|
||||
throw std::runtime_error(format("inconsistent tensor shard type in '%s'", name.c_str()));
|
||||
}
|
||||
}
|
||||
type = first_shard.type;
|
||||
}
|
||||
|
||||
void calc_split_type() {
|
||||
if (shards.at(0).ne.size() == 1 || // 1D tensors are just duplicated in every file
|
||||
shards.size() == 1) { // only one file?
|
||||
split_type = SPLIT_NONE;
|
||||
} else if (name.find("tok_embeddings.") == 0 ||
|
||||
name.find(".attention.wo.weight") != std::string::npos ||
|
||||
name.find(".feed_forward.w2.weight") != std::string::npos) {
|
||||
split_type = SPLIT_BY_COLUMNS;
|
||||
} else {
|
||||
split_type = SPLIT_BY_ROWS;
|
||||
}
|
||||
}
|
||||
|
||||
void calc_ne() {
|
||||
const auto & first_shard = shards.at(0);
|
||||
for (const auto & shard : shards) {
|
||||
if (shard.ne != first_shard.ne) {
|
||||
throw std::runtime_error(format("inconsistent tensor shard shape in '%s': first was %s, other was %s",
|
||||
name.c_str(), llama_format_tensor_shape(first_shard.ne).c_str(), llama_format_tensor_shape(shard.ne).c_str()));
|
||||
}
|
||||
}
|
||||
ne = first_shard.ne;
|
||||
LLAMA_ASSERT(shards.size() <= UINT32_MAX);
|
||||
uint32_t n_shards = (uint32_t) shards.size();
|
||||
switch (split_type) {
|
||||
case SPLIT_NONE:
|
||||
ne = first_shard.ne;
|
||||
break;
|
||||
case SPLIT_BY_COLUMNS:
|
||||
ne = {checked_mul<uint32_t>(first_shard.ne[0], n_shards),
|
||||
first_shard.ne[1]};
|
||||
break;
|
||||
case SPLIT_BY_ROWS:
|
||||
ne = {first_shard.ne[0],
|
||||
checked_mul<uint32_t>(first_shard.ne[1], n_shards)};
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void calc_size() {
|
||||
size = llama_calc_tensor_size(ne, type);
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_load_tensors_map {
|
||||
|
@ -476,13 +394,13 @@ struct llama_file_loader {
|
|||
llama_hparams hparams;
|
||||
llama_vocab vocab;
|
||||
|
||||
llama_file_loader(const char * fname, size_t file_idx, llama_load_tensors_map & tensors_map)
|
||||
llama_file_loader(const char * fname, llama_load_tensors_map & tensors_map)
|
||||
: file(fname, "rb") {
|
||||
fprintf(stderr, "llama.cpp: loading model from %s\n", fname);
|
||||
read_magic();
|
||||
read_hparams();
|
||||
read_vocab();
|
||||
read_tensor_metadata(file_idx, tensors_map);
|
||||
read_tensor_metadata(tensors_map);
|
||||
}
|
||||
void read_magic() {
|
||||
uint32_t magic = file.read_u32();
|
||||
|
@ -539,19 +457,19 @@ struct llama_file_loader {
|
|||
tok_score.score = score;
|
||||
}
|
||||
}
|
||||
void read_tensor_metadata(size_t file_idx, llama_load_tensors_map & tensors_map) {
|
||||
void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
|
||||
while (file.tell() < file.size) {
|
||||
llama_load_tensor_shard shard;
|
||||
llama_load_tensor tensor;
|
||||
uint32_t n_dims = file.read_u32();
|
||||
uint32_t name_len = file.read_u32();
|
||||
shard.type = (enum ggml_type) file.read_u32();
|
||||
shard.ne.resize(n_dims);
|
||||
file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims);
|
||||
tensor.type = (enum ggml_type) file.read_u32();
|
||||
tensor.ne.resize(n_dims);
|
||||
file.read_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * n_dims);
|
||||
std::string name = file.read_string(name_len);
|
||||
if (n_dims < 1 || n_dims > 2) {
|
||||
throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims));
|
||||
}
|
||||
switch (shard.type) {
|
||||
switch (tensor.type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
@ -566,30 +484,20 @@ struct llama_file_loader {
|
|||
case GGML_TYPE_Q6_K:
|
||||
break;
|
||||
default: {
|
||||
throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type));
|
||||
throw std::runtime_error(format("unrecognized tensor type %u\n", tensor.type));
|
||||
}
|
||||
}
|
||||
|
||||
if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) {
|
||||
// skip to the next multiple of 32 bytes
|
||||
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
|
||||
}
|
||||
shard.file_idx = file_idx;
|
||||
shard.file_off = file.tell();
|
||||
// skip to the next multiple of 32 bytes
|
||||
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
|
||||
|
||||
shard.calc_size();
|
||||
file.seek(shard.size, SEEK_CUR);
|
||||
tensor.file_off = file.tell();
|
||||
tensor.name = name;
|
||||
tensor.size = llama_calc_tensor_size(tensor.ne, tensor.type);
|
||||
file.seek(tensor.size, SEEK_CUR);
|
||||
|
||||
auto it = tensors_map.name_to_idx.find(name);
|
||||
size_t idx;
|
||||
if (it != tensors_map.name_to_idx.end()) {
|
||||
idx = it->second;
|
||||
} else {
|
||||
tensors_map.tensors.emplace_back(name);
|
||||
idx = tensors_map.tensors.size() - 1;
|
||||
tensors_map.name_to_idx.emplace(name, idx);
|
||||
}
|
||||
tensors_map.tensors.at(idx).shards.push_back(shard);
|
||||
tensors_map.tensors.push_back(tensor);
|
||||
tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -659,56 +567,19 @@ struct llama_file_saver {
|
|||
};
|
||||
|
||||
struct llama_model_loader {
|
||||
std::vector<std::unique_ptr<llama_file_loader>> file_loaders;
|
||||
std::unique_ptr<llama_file_loader> file_loader;
|
||||
llama_load_tensors_map tensors_map;
|
||||
bool use_mmap;
|
||||
size_t num_ggml_tensors_created = 0;
|
||||
struct ggml_context * ggml_ctx = NULL;
|
||||
std::unique_ptr<llama_mmap> mapping;
|
||||
|
||||
llama_model_loader(const std::string & fname_base, bool use_mmap, bool vocab_only) {
|
||||
auto * first_file = new llama_file_loader(fname_base.c_str(), 0, tensors_map);
|
||||
file_loaders.emplace_back(first_file);
|
||||
uint32_t n_parts = vocab_only ? 1 : guess_n_parts();
|
||||
for (uint32_t i = 1; i < n_parts; i++) {
|
||||
std::string fname = fname_base + "." + std::to_string(i);
|
||||
auto * ith_file = new llama_file_loader(fname.c_str(), i, tensors_map);
|
||||
file_loaders.emplace_back(ith_file);
|
||||
if (ith_file->hparams != first_file->hparams) {
|
||||
throw std::runtime_error(format("llama.cpp: hparams inconsistent between files"));
|
||||
}
|
||||
}
|
||||
llama_model_loader(const std::string & fname_base, bool use_mmap) {
|
||||
file_loader = std::unique_ptr<llama_file_loader>(new llama_file_loader(fname_base.c_str(), tensors_map));
|
||||
if (!llama_mmap::SUPPORTED) {
|
||||
use_mmap = false;
|
||||
}
|
||||
if (use_mmap && alignment_prevents_mmap()) {
|
||||
fprintf(stderr, "llama.cpp: can't use mmap because tensors are not aligned; convert to new format to avoid this\n");
|
||||
use_mmap = false;
|
||||
}
|
||||
this->use_mmap = use_mmap;
|
||||
for (llama_load_tensor & lt : tensors_map.tensors) {
|
||||
lt.calc_all();
|
||||
}
|
||||
}
|
||||
|
||||
bool alignment_prevents_mmap() {
|
||||
for (const llama_load_tensor & lt : tensors_map.tensors) {
|
||||
for (const llama_load_tensor_shard & shard : lt.shards) {
|
||||
if (shard.file_off & 3) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t guess_n_parts() const {
|
||||
auto it = tensors_map.name_to_idx.find("tok_embeddings.weight");
|
||||
if (it == tensors_map.name_to_idx.end()) {
|
||||
throw std::runtime_error(std::string("missing tok_embeddings.weight"));
|
||||
}
|
||||
const llama_load_tensor & lt = tensors_map.tensors.at(it->second);
|
||||
return file_loaders.at(0)->hparams.n_embd / lt.shards.at(0).ne.at(0);
|
||||
}
|
||||
|
||||
void calc_sizes(size_t * ctx_size_p, size_t * mmapped_size_p) const {
|
||||
|
@ -774,7 +645,7 @@ struct llama_model_loader {
|
|||
}
|
||||
|
||||
if (use_mmap) {
|
||||
mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size, ggml_is_numa()));
|
||||
mapping.reset(new llama_mmap(&file_loader->file, prefetch_size, ggml_is_numa()));
|
||||
if (lmlock) {
|
||||
lmlock->init(mapping->addr);
|
||||
}
|
||||
|
@ -830,45 +701,13 @@ struct llama_model_loader {
|
|||
|
||||
void load_data_for(llama_load_tensor & lt) {
|
||||
if (use_mmap) {
|
||||
LLAMA_ASSERT(lt.shards.size() == 1);
|
||||
lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off;
|
||||
} else if (lt.split_type == SPLIT_NONE) {
|
||||
llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;
|
||||
file.seek(lt.shards.at(0).file_off, SEEK_SET);
|
||||
lt.data = (uint8_t *) mapping->addr + lt.file_off;
|
||||
} else {
|
||||
llama_file & file = file_loader->file;
|
||||
file.seek(lt.file_off, SEEK_SET);
|
||||
file.read_raw(lt.data, lt.size);
|
||||
} else if (lt.split_type == SPLIT_BY_ROWS) {
|
||||
size_t offset = 0;
|
||||
for (llama_load_tensor_shard & shard : lt.shards) {
|
||||
llama_file & file = file_loaders.at(shard.file_idx)->file;
|
||||
file.seek(shard.file_off, SEEK_SET);
|
||||
file.read_raw(lt.data + offset, shard.size);
|
||||
offset += shard.size;
|
||||
}
|
||||
LLAMA_ASSERT(offset == lt.size);
|
||||
} else if (lt.split_type == SPLIT_BY_COLUMNS) {
|
||||
// Let's load the data into temporary buffers to ensure the OS performs large loads.
|
||||
std::vector<llama_buffer> tmp_bufs(lt.shards.size());
|
||||
for (size_t i = 0; i < lt.shards.size(); i++) {
|
||||
llama_load_tensor_shard & shard = lt.shards.at(i);
|
||||
llama_file & file = file_loaders.at(shard.file_idx)->file;
|
||||
file.seek(shard.file_off, SEEK_SET);
|
||||
tmp_bufs.at(i).resize(shard.size);
|
||||
file.read_raw(tmp_bufs.at(i).addr, shard.size);
|
||||
}
|
||||
// Then reshape.
|
||||
size_t num_rows = lt.ne.at(1);
|
||||
size_t per_shard_row_size = lt.shards.at(0).size / num_rows;
|
||||
size_t out_offset = 0;
|
||||
for (size_t row = 0; row < num_rows; row++) {
|
||||
for (llama_buffer & tmp_buf : tmp_bufs) {
|
||||
memcpy(lt.data + out_offset,
|
||||
tmp_buf.addr + row * per_shard_row_size,
|
||||
per_shard_row_size);
|
||||
out_offset += per_shard_row_size;
|
||||
}
|
||||
}
|
||||
LLAMA_ASSERT(out_offset == lt.size);
|
||||
}
|
||||
|
||||
if (0) {
|
||||
print_checksum(lt);
|
||||
}
|
||||
|
@ -1067,12 +906,12 @@ static void llama_model_load_internal(
|
|||
|
||||
model.t_start_us = ggml_time_us();
|
||||
|
||||
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap, vocab_only));
|
||||
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
|
||||
|
||||
vocab = std::move(ml->file_loaders.at(0)->vocab);
|
||||
model.hparams = ml->file_loaders.at(0)->hparams;
|
||||
vocab = std::move(ml->file_loader->vocab);
|
||||
model.hparams = ml->file_loader->hparams;
|
||||
model.n_gpu_layers = n_gpu_layers;
|
||||
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
|
||||
llama_file_version file_version = ml->file_loader->file_version;
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
{
|
||||
|
@ -1106,7 +945,6 @@ static void llama_model_load_internal(
|
|||
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
|
||||
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
|
||||
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
|
||||
fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
|
||||
}
|
||||
|
||||
|
@ -1369,22 +1207,26 @@ static bool llama_model_load(
|
|||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - tokens: new batch of tokens to process
|
||||
// - n_past: the context size so far
|
||||
// - n_threads: number of threads to use
|
||||
// - cgraph_fname: filename of the exported computation graph
|
||||
// - lctx: llama context
|
||||
// - tokens: new batch of tokens to process
|
||||
// - embd embeddings input
|
||||
// - n_tokens number of tokens
|
||||
// - n_past: the context size so far
|
||||
// - n_threads: number of threads to use
|
||||
//
|
||||
static bool llama_eval_internal(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const int n_threads,
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const int n_threads,
|
||||
const char * cgraph_fname) {
|
||||
|
||||
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
|
||||
|
||||
// enforce that the first token is BOS
|
||||
if (n_past == 0 && tokens[0] != llama_token_bos()) {
|
||||
if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) {
|
||||
fprintf(stderr, "%s: first token must be BOS\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
@ -1424,12 +1266,18 @@ static bool llama_eval_internal(
|
|||
ggml_cgraph gf = {};
|
||||
gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
ggml_set_name(embd, "embd");
|
||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (tokens) {
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
ggml_set_name(embd, "embd");
|
||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
||||
} else {
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
(void) i_gpu_start;
|
||||
|
@ -2451,9 +2299,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
nthread = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false,
|
||||
/*vocab_only*/ false));
|
||||
llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype);
|
||||
std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false));
|
||||
llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype);
|
||||
|
||||
#ifdef GGML_USE_K_QUANTS
|
||||
int n_attention_wv = 0;
|
||||
|
@ -2654,6 +2501,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
@ -2874,7 +2723,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
|
||||
// create a name -> tensor map of the model to accelerate lookups
|
||||
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
||||
for (auto & kv: model.tensors_by_name) {
|
||||
for (const auto & kv: model.tensors_by_name) {
|
||||
model_tensors.insert(kv);
|
||||
}
|
||||
|
||||
|
@ -2885,7 +2734,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
llama_buffer base_buf;
|
||||
if (path_base_model) {
|
||||
fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
|
||||
model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
|
||||
|
||||
size_t ctx_size;
|
||||
size_t mmapped_size;
|
||||
|
@ -2903,7 +2752,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
|
||||
// maybe this should in llama_model_loader
|
||||
if (model_loader->use_mmap) {
|
||||
model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ 0, ggml_is_numa()));
|
||||
model_loader->mapping.reset(new llama_mmap(&model_loader->file_loader->file, /* prefetch */ 0, ggml_is_numa()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2964,7 +2813,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ggml_tensor* lora_tensor;
|
||||
ggml_tensor * lora_tensor;
|
||||
if (n_dims == 2) {
|
||||
lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
|
||||
}
|
||||
|
@ -2972,6 +2821,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
return 1;
|
||||
}
|
||||
ggml_set_name(lora_tensor, "lora_tensor");
|
||||
|
||||
// load tensor data
|
||||
size_t offset = fin.tellg();
|
||||
|
@ -2987,6 +2837,21 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
|
||||
|
||||
ggml_tensor * dest_t = model_tensors[base_name];
|
||||
|
||||
offload_func_t offload_func = llama_nop;
|
||||
offload_func_t offload_func_force_inplace = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
|
||||
if (dest_t->type != GGML_TYPE_F16) {
|
||||
throw std::runtime_error(format(
|
||||
"%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
|
||||
}
|
||||
offload_func = ggml_cuda_assign_buffers;
|
||||
offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
ggml_tensor * base_t;
|
||||
if (model_loader) {
|
||||
// load from base model
|
||||
|
@ -3014,7 +2879,12 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
}
|
||||
|
||||
ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
|
||||
GGML_ASSERT(loraA->type == GGML_TYPE_F32);
|
||||
ggml_set_name(loraA, "loraA");
|
||||
|
||||
ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
|
||||
GGML_ASSERT(loraB->type == GGML_TYPE_F32);
|
||||
ggml_set_name(loraB, "loraB");
|
||||
|
||||
if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
||||
fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
||||
|
@ -3024,19 +2894,32 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
|
|||
|
||||
// w = w + BA*s
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||
offload_func(BA);
|
||||
ggml_set_name(BA, "BA");
|
||||
|
||||
if (scaling != 1.0f) {
|
||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
||||
ggml_set_name(scale_tensor, "scale_tensor");
|
||||
|
||||
BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
|
||||
offload_func(BA);
|
||||
ggml_set_name(BA, "BA_scaled");
|
||||
}
|
||||
|
||||
ggml_tensor * r;
|
||||
if (base_t == dest_t) {
|
||||
r = ggml_add_inplace(lora_ctx, dest_t, BA);
|
||||
offload_func_force_inplace(r);
|
||||
ggml_set_name(r, "r_add_inplace");
|
||||
}
|
||||
else {
|
||||
r = ggml_add(lora_ctx, base_t, BA);
|
||||
offload_func(r);
|
||||
ggml_set_name(r, "r_add");
|
||||
|
||||
r = ggml_cpy(lora_ctx, r, dest_t);
|
||||
offload_func(r);
|
||||
ggml_set_name(r, "r_cpy");
|
||||
}
|
||||
|
||||
struct ggml_cgraph gf = ggml_build_forward(r);
|
||||
|
@ -3421,7 +3304,29 @@ int llama_eval(
|
|||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads) {
|
||||
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
|
||||
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
|
||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
// TODO: fix this
|
||||
if (!ctx->has_evaluated_once) {
|
||||
ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
|
||||
ctx->has_evaluated_once = true;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads) {
|
||||
if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
|
||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
@ -3442,7 +3347,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
|
|||
|
||||
const std::vector<llama_token> tmp(n_batch, llama_token_bos());
|
||||
|
||||
if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
|
||||
if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
|
||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
8
llama.h
8
llama.h
|
@ -226,6 +226,14 @@ extern "C" {
|
|||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Same as llama_eval, but use float matrix input directly.
|
||||
LLAMA_API int llama_eval_embd(
|
||||
struct llama_context * ctx,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Export a static computation graph for context of 511 and batch size of 1
|
||||
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
||||
// parameters here to keep things simple
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue