llama : second attempt to refactor vision API

commit 0a81051ae2 (parent 2a458d1a9d)

20 changed files with 695 additions and 145 deletions
@@ -1403,7 +1403,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, const std::string & value) {
            params.image.emplace_back(value);
        }
-   ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+   ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION}));
    if (llama_supports_rpc()) {
        add_opt(common_arg(
            {"--rpc"}, "SERVERS",
@@ -79,6 +79,7 @@ enum llama_example {
    LLAMA_EXAMPLE_LOOKUP,
    LLAMA_EXAMPLE_PARALLEL,
    LLAMA_EXAMPLE_TTS,
    LLAMA_EXAMPLE_VISION,

    LLAMA_EXAMPLE_COUNT,
};
@@ -53,6 +53,7 @@ else()
    add_subdirectory(tokenize)
    add_subdirectory(tts)
    add_subdirectory(gen-docs)
    add_subdirectory(vision)
    if (NOT GGML_BACKEND_DL)
        # these examples use the backends directly and cannot be built with dynamic loading
        add_subdirectory(convert-llama2c-to-ggml)
examples/vision/CMakeLists.txt (new file, 5 lines)

@@ -0,0 +1,5 @@
set(TARGET llama-vision)
add_executable(${TARGET} vision.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
examples/vision/README.md (new file, 3 lines)

@@ -0,0 +1,3 @@
# llama.cpp/example/simple-vision

Minimal demo for vision API
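After building the new target (for example with cmake --build build --target llama-vision; the exact commands depend on your build setup), a hypothetical invocation following the usage string printed by vision.cpp would look like:

    llama-vision -m model.gguf --image img_path -n 64 -ngl 99

model.gguf and img_path are placeholders. The default prompt already contains the <img_placement> marker, so -p only needs to be passed to override it.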
examples/vision/vision.cpp (new file, 211 lines)

@@ -0,0 +1,211 @@
#include "llama.h"
#include "common.h"
#include "arg.h"
#include "log.h"
#include "sampling.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <fstream>

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [--image img_path] [-p prompt]\n", argv[0]);
    printf("\n");
}

static llama_vision_bitmap * load_image_from_file(const char * fname) {
    std::ifstream file(fname, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Unable to open file");
    }
    std::vector<char> image_bytes = std::vector<char>(
        std::istreambuf_iterator<char>(file),
        std::istreambuf_iterator<char>());
    // decode image to byte array
    int nx, ny, nc;
    auto * bytes = (unsigned char *) image_bytes.data();
    auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
    if (!img) {
        throw std::runtime_error("failed to decode image bytes");
    }
    // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
    // GGML_ASSERT(nc == 3);
    // for (int y = 0; y < ny; y++) {
    //     for (int x = 0; x < nx; x++) {
    //         unsigned char * pix = img + x*nc + y*nc*nx;
    //         printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
    //     }
    //     printf("\n");
    // }
    // printf("\n");
    llama_vision_bitmap * result = llama_vision_bitmap_init(nx, ny);
    memcpy(result->data, img, nx*ny*3);
    stbi_image_free(img);
    return result;
}

// split string by a `std::string delim` instead of `char delim`
static std::vector<std::string> string_split(std::string s, const std::string & delimiter) {
    std::vector<std::string> tokens;
    size_t pos = 0;
    std::string token;
    while ((pos = s.find(delimiter)) != std::string::npos) {
        token = s.substr(0, pos);
        tokens.push_back(token);
        s.erase(0, pos + delimiter.length());
    }
    tokens.push_back(s);
    return tokens;
}

struct tokenized_part {
    llama_tokens tokens;
    bool is_image;
};

// TODO: this function is hacky, need to be improved
// static const llama_token TOKEN_IMG_PLACEMENT = -1000;
static const std::string IMG_PLACEMENT = "<img_placement>";
static std::vector<tokenized_part> tokenize_with_img_placement(
        const llama_vocab * vocab,
        const std::string & text,
        bool add_special,
        bool parse_special) {
    std::vector<std::string> parts = string_split(text, IMG_PLACEMENT);
    std::vector<tokenized_part> output;
    for (const auto & part : parts) {
        //printf("tokenizing part: %s\n", part.c_str());
        bool add_bos = &parts.front() == &part;
        auto tokens = common_tokenize(vocab, part, add_special && add_bos, parse_special);
        if (tokens.empty()) {
            continue;
        }
        output.push_back({std::move(tokens), false});
        if (&parts.back() != &part) {
            // add image token to middle of 2 parts
            output.push_back({{}, true});
        }
    }
    return output;
}

int main(int argc, char ** argv) {
    common_params params;

    // default prompt for llava 1.5
    params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
                    "USER:<img_placement>\nwhat did you see?\nASSISTANT:";
    params.n_predict = 64;
    params.n_batch = 2048;
    params.n_ubatch = 1024;
    params.n_gpu_layers = 99;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_VISION, print_usage)) {
        return 1;
    }

    common_init();
    common_init_result llama_init = common_init_from_params(params);
    llama_context * ctx = llama_init.context.get();
    const llama_model * model = llama_init.model.get();
    const llama_vocab * vocab = llama_model_get_vocab(model);

    struct common_sampler * smpl = common_sampler_init(model, params.sampling);

    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
    int n_past = 0;
    int n_prompt = 0;

    // process image
    llama_vision_patches * img_patches = nullptr;
    {
        const char * img_path = params.image[0].c_str();
        if (params.image[0].empty()) {
            LOG_ERR("no image path provided\n");
            return 1;
        }
        llama_vision_bitmap * img = load_image_from_file(img_path);
        LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny);
        img_patches = llama_vision_patches_init(ctx, img);
        if (!img_patches) {
            LOG_ERR("failed to create image patches\n");
            return 1;
        }
        if (llama_vision_encode(ctx, img_patches)) {
            LOG_ERR("failed to encode image\n");
            return 1;
        }
        LOG_INF("encoded image\n");
    }

    // process prompt
    {
        std::vector<tokenized_part> parts = tokenize_with_img_placement(vocab, params.prompt, true, true);
        for (const tokenized_part & part : parts) {
            if (!part.is_image) {
                for (const llama_token & token : part.tokens) {
                    //LOG_INF("%d -> %s\n", token, common_token_to_piece(ctx, token).c_str());
                    common_batch_add(batch, token, n_past++, {0}, &part == &parts.back());
                }
                LOG_INF("eval text batch (%d tokens)\n", batch.n_tokens);
                if (llama_decode(ctx, batch)) {
                    LOG_ERR("failed to decode text prompt\n");
                    return 1;
                }
            } else {
                auto * img_embd = llama_vision_get_output_tensor(ctx);
                // std::vector<float> output_debug(ggml_nelements(img_embd));
                // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd));
                // for (int row = 0; row < 10; row++) {
                //     int off = row * img_embd->ne[0];
                //     printf("... %f %f %f\n", output_debug[off], output_debug[off+1], output_debug[off+2]);
                // }
                // exit(1);
                llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, 0);
                n_past += batch_img.n_tokens;
                LOG_INF("eval image batch (%d embeddings)\n", batch_img.n_tokens);
                if (llama_decode(ctx, batch_img)) {
                    LOG_ERR("failed to decode image prompt\n");
                    return 1;
                }
                llama_batch_free(batch_img);
            }
        }
        n_prompt = n_past;
        LOG_INF("prompt processed, %d tokens\n", n_prompt);
    }

    // generate response
    while (true) {
        int n_generated = n_past - n_prompt;
        if (n_generated > params.n_predict) {
            printf("\n");
            break;
        }

        llama_token token_id = common_sampler_sample(smpl, ctx, -1);
        common_sampler_accept(smpl, token_id, true);
        printf("%s", common_token_to_piece(ctx, token_id).c_str());
        fflush(stdout);

        if (llama_vocab_is_eog(vocab, token_id)) {
            printf("\n");
            break;
        }

        // eval the token
        common_batch_clear(batch);
        common_batch_add(batch, token_id, n_past++, {0}, true);
        if (llama_decode(ctx, batch)) {
            LOG_ERR("failed to decode token\n");
            break;
        }
    }

    return 0;
}
@@ -229,6 +229,8 @@ extern "C" {
        bool sorted;
    } llama_token_data_array;

    struct llama_vision_patches;

    // represent an RGB image
    // size of data must be equal to 3*nx*ny
    typedef struct llama_vision_bitmap {

@@ -237,8 +239,6 @@
        unsigned char * data;
    } llama_vision_bitmap;

    struct llama_vision_patches;

    typedef bool (*llama_progress_callback)(float progress, void * user_data);

    // Input data for llama_decode

@@ -263,6 +263,8 @@
        int32_t      *  n_seq_id;
        llama_seq_id ** seq_id;
        int8_t       *  logits; // TODO: rename this to "output"

        struct ggml_tensor * embd_tensor;
    } llama_batch;

    enum llama_model_kv_override_type {

@@ -854,6 +856,10 @@
            int32_t embd,
            int32_t n_seq_max);

    // Allocates a batch based on a tensor, only used by vision API for now
    // Unlike llama_batch_get_one, this will need to be freed after use
    LLAMA_API struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id);

    // Frees a batch of tokens allocated with llama_batch_init()
    LLAMA_API void llama_batch_free(struct llama_batch batch);

@@ -1272,6 +1278,22 @@
    // TODO: extend in the future
    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);

    //
    // Vision API
    //

    // Container for RGB bitmap
    LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny);
    LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp);

    // Create patches from the RGB bitmap
    LLAMA_API struct llama_vision_patches * llama_vision_patches_init(struct llama_context * ctx, llama_vision_bitmap * bmp);
    LLAMA_API void llama_vision_patches_free(struct llama_vision_patches * p);

    // Encode patches into embeddings
    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p);
    LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);

    //
    // Model split
    //
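Taken together, the intended call sequence for these declarations is the one exercised by examples/vision/vision.cpp. A condensed sketch, not part of the diff (ctx, n_past, nx, ny and rgb_pixels are assumed to exist in the caller; error handling and cleanup of the bitmap and patches are omitted):

    // wrap raw RGB pixels (3*nx*ny bytes) in a bitmap
    llama_vision_bitmap * bmp = llama_vision_bitmap_init(nx, ny);
    memcpy(bmp->data, rgb_pixels, 3*nx*ny);

    // preprocess into patches and run the vision encoder
    llama_vision_patches * patches = llama_vision_patches_init(ctx, bmp);
    llama_vision_encode(ctx, patches);

    // feed the resulting embeddings to the language model as a batch
    struct ggml_tensor * img_embd = llama_vision_get_output_tensor(ctx);
    llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, /*seq_id =*/ 0);
    llama_decode(ctx, batch_img);
    n_past += batch_img.n_tokens;
    llama_batch_free(batch_img);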
@@ -24,6 +24,7 @@ add_library(llama
            llama-quant.cpp
            llama-sampling.cpp
            llama-vocab.cpp
            llama-vision.cpp
            unicode.h
            unicode.cpp
            unicode-data.cpp
@ -65,6 +65,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
|
||||
{ VISION_ARCH_LLAVA, "llava" },
|
||||
{ VISION_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_GENERAL_TYPE, "general.type" },
|
||||
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
|
||||
|
@ -189,6 +194,27 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
|
||||
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
|
||||
|
||||
{ LLM_KV_VISION_TYPE, "vision.type" },
|
||||
{ LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" },
|
||||
{ LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" },
|
||||
{ LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" },
|
||||
{ LLM_KV_VISION_IMAGE_STD, "vision.image_std" },
|
||||
{ LLM_KV_VISION_CLIP_ARCHITECTURE, "vision.clip.architecture" },
|
||||
{ LLM_KV_VISION_CLIP_CONTEXT_LENGTH, "vision.clip.context_length" },
|
||||
{ LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, "vision.clip.embedding_length" },
|
||||
{ LLM_KV_VISION_CLIP_BLOCK_COUNT, "vision.clip.block_count" },
|
||||
{ LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, "vision.clip.feed_forward_length" },
|
||||
{ LLM_KV_VISION_CLIP_PROJECTION_TYPE, "vision.clip.projection_type" },
|
||||
{ LLM_KV_VISION_CLIP_PROJECTION_DIM, "vision.clip.projection_dim" },
|
||||
{ LLM_KV_VISION_CLIP_USE_GELU, "vision.clip.use_gelu" },
|
||||
{ LLM_KV_VISION_CLIP_MAX_POS_EMBD, "vision.clip.max_position_embeddings" },
|
||||
{ LLM_KV_VISION_CLIP_MAX_SLICES, "vision.clip.max_slices" },
|
||||
{ LLM_KV_VISION_CLIP_PROJECTOR_TYPE, "vision.clip.projector_type" },
|
||||
{ LLM_KV_VISION_CLIP_SELECT_LAYER, "vision.clip.select_layer" },
|
||||
{ LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, "vision.clip.patch_merge_type" },
|
||||
{ LLM_KV_VISION_CLIP_HEAD_COUNT, "vision.clip.attention.head_count" },
|
||||
{ LLM_KV_VISION_CLIP_LAYERNORM_EPS, "vision.clip.attention.layer_norm_epsilon" },
|
||||
|
||||
// deprecated
|
||||
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
|
||||
|
@ -1300,6 +1326,28 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
},
|
||||
};
|
||||
|
||||
static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION_TENSOR_NAMES = {
|
||||
{
|
||||
VISION_ARCH_LLAVA,
|
||||
{
|
||||
{ VISION_TENSOR_MMPROJ, "v.mmproj_%d" },
|
||||
{ VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
|
||||
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
|
||||
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
|
||||
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
|
||||
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
|
||||
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
|
||||
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
|
||||
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
|
||||
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
|
||||
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
|
||||
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
|
||||
{ VISION_TENSOR_PRE_NORM, "v.pre_norm" },
|
||||
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
|
@ -1449,7 +1497,8 @@ std::string LLM_KV::operator()(llm_kv kv) const {
|
|||
return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
}
|
||||
|
||||
std::string LLM_TN_IMPL::str() const {
|
||||
template<>
|
||||
std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
|
||||
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
|
||||
return "__missing__";
|
||||
}
|
||||
|
@ -1464,6 +1513,22 @@ std::string LLM_TN_IMPL::str() const {
|
|||
return name;
|
||||
}
|
||||
|
||||
template<>
|
||||
std::string BASE_TN_IMPL<vision_arch, vision_tensor>::str() const {
|
||||
if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) {
|
||||
return "__missing__";
|
||||
}
|
||||
|
||||
std::string name = ::format(VISION_TENSOR_NAMES.at(arch).at(tensor), bid, xid);
|
||||
|
||||
if (suffix != nullptr) {
|
||||
name += ".";
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
const char * llm_arch_name(llm_arch arch) {
|
||||
auto it = LLM_ARCH_NAMES.find(arch);
|
||||
if (it == LLM_ARCH_NAMES.end()) {
|
||||
|
@ -1482,6 +1547,16 @@ llm_arch llm_arch_from_string(const std::string & name) {
|
|||
return LLM_ARCH_UNKNOWN;
|
||||
}
|
||||
|
||||
vision_arch vision_arch_from_string(const std::string & name) {
|
||||
for (const auto & kv : VISION_ARCH_NAMES) { // NOLINT
|
||||
if (kv.second == name) {
|
||||
return kv.first;
|
||||
}
|
||||
}
|
||||
|
||||
return VISION_ARCH_UNKNOWN;
|
||||
}
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
|
||||
return LLM_TENSOR_INFOS.at(tensor);
|
||||
}
|
||||
|
|
|
@ -69,6 +69,11 @@ enum llm_arch {
|
|||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
enum vision_arch {
|
||||
VISION_ARCH_UNKNOWN,
|
||||
VISION_ARCH_LLAVA,
|
||||
};
|
||||
|
||||
enum llm_kv {
|
||||
LLM_KV_GENERAL_TYPE,
|
||||
LLM_KV_GENERAL_ARCHITECTURE,
|
||||
|
@ -193,6 +198,27 @@ enum llm_kv {
|
|||
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
||||
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
||||
|
||||
LLM_KV_VISION_TYPE,
|
||||
LLM_KV_VISION_IMAGE_SIZE,
|
||||
LLM_KV_VISION_PATCH_SIZE,
|
||||
LLM_KV_VISION_IMAGE_MEAN,
|
||||
LLM_KV_VISION_IMAGE_STD,
|
||||
LLM_KV_VISION_CLIP_ARCHITECTURE,
|
||||
LLM_KV_VISION_CLIP_CONTEXT_LENGTH,
|
||||
LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,
|
||||
LLM_KV_VISION_CLIP_BLOCK_COUNT,
|
||||
LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_VISION_CLIP_PROJECTION_TYPE,
|
||||
LLM_KV_VISION_CLIP_PROJECTION_DIM,
|
||||
LLM_KV_VISION_CLIP_USE_GELU,
|
||||
LLM_KV_VISION_CLIP_MAX_POS_EMBD,
|
||||
LLM_KV_VISION_CLIP_MAX_SLICES,
|
||||
LLM_KV_VISION_CLIP_PROJECTOR_TYPE,
|
||||
LLM_KV_VISION_CLIP_SELECT_LAYER,
|
||||
LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE,
|
||||
LLM_KV_VISION_CLIP_HEAD_COUNT,
|
||||
LLM_KV_VISION_CLIP_LAYERNORM_EPS,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
|
@ -328,6 +354,23 @@ enum llm_tensor {
|
|||
LLM_TENSOR_POS_NET_ATTN_OUT,
|
||||
};
|
||||
|
||||
enum vision_tensor {
|
||||
VISION_TENSOR_MMPROJ,
|
||||
VISION_TENSOR_ENC_EMBD_CLS,
|
||||
VISION_TENSOR_ENC_EMBD_PATCH,
|
||||
VISION_TENSOR_ENC_EMBD_POS,
|
||||
VISION_TENSOR_ENC_ATTN_Q,
|
||||
VISION_TENSOR_ENC_ATTN_K,
|
||||
VISION_TENSOR_ENC_ATTN_V,
|
||||
VISION_TENSOR_ENC_INPUT_NORM,
|
||||
VISION_TENSOR_ENC_OUTPUT,
|
||||
VISION_TENSOR_ENC_OUTPUT_NORM,
|
||||
VISION_TENSOR_ENC_FFN_UP,
|
||||
VISION_TENSOR_ENC_FFN_DOWN,
|
||||
VISION_TENSOR_PRE_NORM,
|
||||
VISION_TENSOR_POST_NORM,
|
||||
};
|
||||
|
||||
enum llm_tensor_layer {
|
||||
LLM_TENSOR_LAYER_INPUT,
|
||||
LLM_TENSOR_LAYER_REPEATING,
|
||||
|
@ -351,9 +394,10 @@ struct LLM_KV {
|
|||
// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias"
|
||||
// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
|
||||
//
|
||||
struct LLM_TN_IMPL {
|
||||
const llm_arch arch;
|
||||
const llm_tensor tensor;
|
||||
template<typename Tname, typename Ttensor>
|
||||
struct BASE_TN_IMPL {
|
||||
const Tname arch;
|
||||
const Ttensor tensor;
|
||||
const char * const suffix;
|
||||
const int bid;
|
||||
const int xid;
|
||||
|
@ -364,15 +408,16 @@ struct LLM_TN_IMPL {
|
|||
return str();
|
||||
}
|
||||
|
||||
friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
|
||||
friend bool operator==(const std::string & str, const BASE_TN_IMPL & tn) {
|
||||
return str == tn.str();
|
||||
}
|
||||
|
||||
friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
|
||||
friend bool operator!=(const std::string & str, const BASE_TN_IMPL & tn) {
|
||||
return str != tn.str();
|
||||
}
|
||||
};
|
||||
|
||||
using LLM_TN_IMPL = BASE_TN_IMPL<llm_arch, llm_tensor>;
|
||||
struct LLM_TN {
|
||||
LLM_TN(llm_arch arch) : arch(arch) {}
|
||||
|
||||
|
@ -387,6 +432,20 @@ struct LLM_TN {
|
|||
}
|
||||
};
|
||||
|
||||
struct VISION_TN {
|
||||
VISION_TN(vision_arch arch) : arch(arch) {}
|
||||
|
||||
vision_arch arch;
|
||||
|
||||
BASE_TN_IMPL<vision_arch, vision_tensor> operator()(vision_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, suffix, bid, xid };
|
||||
}
|
||||
|
||||
BASE_TN_IMPL<vision_arch, vision_tensor> operator()(vision_tensor tensor, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, nullptr, bid, xid };
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct llm_tensor_info {
|
||||
llm_tensor_layer layer;
|
||||
|
@ -397,4 +456,6 @@ const char * llm_arch_name(llm_arch arch);
|
|||
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
vision_arch vision_arch_from_string(const std::string & name);
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
||||
|
|
|
@ -31,6 +31,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
|||
/*n_seq_id =*/ ubatch_n_seq_id.data(),
|
||||
/*seq_id =*/ ubatch_seq_id.data(),
|
||||
/*output =*/ ubatch_output.data(),
|
||||
/*embd_tensor =*/ nullptr,
|
||||
};
|
||||
return ubatch;
|
||||
}
|
||||
|
@ -55,7 +56,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
|
|||
} else {
|
||||
ubatch.token = nullptr;
|
||||
}
|
||||
if (batch->embd) {
|
||||
if (batch->embd_tensor) {
|
||||
ubatch.embd_tensor = batch->embd_tensor;
|
||||
} else if (batch->embd) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
memcpy(
|
||||
|
@ -139,7 +142,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
|
|||
|
||||
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
|
||||
ubatch.equal_seqs = false;
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
|
@ -152,7 +155,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
|
|||
|
||||
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
|
||||
if (!seq.empty()) {
|
||||
size_t length = 0;
|
||||
size_t n_tokens_in_ubatch = 0;
|
||||
|
@ -179,7 +182,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
|
|||
|
||||
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[seq.size() - 1];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
|
@ -320,6 +323,7 @@ struct llama_batch llama_batch_get_one(
|
|||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*embd_tensor =*/ nullptr,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -332,6 +336,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*embd_tensor =*/ nullptr,
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
|
@ -353,6 +358,35 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|||
return batch;
|
||||
}
|
||||
|
||||
struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) {
|
||||
GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1);
|
||||
int32_t n_tokens = tensor->ne[1];
|
||||
llama_batch batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*embd_tensor =*/ tensor,
|
||||
};
|
||||
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
||||
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
||||
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1));
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
batch.pos [i] = p0 + i;
|
||||
batch.seq_id [i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
|
||||
batch.seq_id [i][0] = seq_id;
|
||||
batch.n_seq_id[i] = 1;
|
||||
}
|
||||
batch.seq_id[n_tokens] = nullptr;
|
||||
|
||||
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||
|
||||
return batch;
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch batch) {
|
||||
if (batch.token) free(batch.token);
|
||||
if (batch.embd) free(batch.embd);
|
||||
|
|
|
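As the comment in llama.h notes, the batch produced by llama_batch_get_one_from_tensor owns its pos, n_seq_id, seq_id and logits arrays, unlike llama_batch_get_one, so the caller must release it. A minimal usage sketch, assuming img_embd holds the encoder output tensor and n_past tracks the current position:

    llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, /*seq_id =*/ 0);
    if (llama_decode(ctx, batch_img) != 0) {
        // handle the failure
    }
    n_past += batch_img.n_tokens;
    llama_batch_free(batch_img); // frees the arrays allocated by llama_batch_get_one_from_tensor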
@@ -21,6 +21,8 @@ struct llama_ubatch {
    int32_t      *  n_seq_id; // [n_seqs]
    llama_seq_id ** seq_id;   // [n_seqs]
    int8_t       *  output;   // [n_tokens]

    struct ggml_tensor * embd_tensor;
};

struct llama_sbatch_seq {
@@ -73,7 +73,7 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
        ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
    }

-   if (ubatch.embd) {
+   if (ubatch.embd && !ubatch.embd_tensor) {
        const int64_t n_embd   = hparams.n_embd;
        const int64_t n_tokens = ubatch.n_tokens;
@@ -6,6 +6,7 @@
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"
#include "llama-vision.h"

#include "ggml-cpp.h"

@@ -107,6 +108,9 @@ struct llama_context {
    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]

    // vision
    clip_context vctx;
};

// TODO: make these methods of llama_context
@@ -375,6 +375,7 @@ namespace GGUFMeta {

template bool llama_model_loader::get_key<bool>       (enum llm_kv kid, bool & result, bool required);
template bool llama_model_loader::get_key<float>      (enum llm_kv kid, float & result, bool required);
template bool llama_model_loader::get_key<int32_t>    (enum llm_kv kid, int32_t & result, bool required);
template bool llama_model_loader::get_key<uint32_t>   (enum llm_kv kid, uint32_t & result, bool required);
template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);

@@ -439,6 +440,7 @@
// TODO: this is not very clever - figure out something better
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<std::array<float, 3>>(enum llm_kv kid, std::array<float, 3> & result, uint32_t n, bool required);

llama_model_loader::llama_model_loader(
        const std::string & fname,
@ -1245,6 +1245,54 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
}
|
||||
|
||||
hparams.rope_type = llama_model_rope_type(this);
|
||||
|
||||
// vision model
|
||||
auto & vparams = clip.hparams;
|
||||
std::string vision_type;
|
||||
ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
|
||||
if (vision_type == "clip-vit") {
|
||||
LLAMA_LOG_INFO("%s: loading clip-vit vision model\n", __func__);
|
||||
has_vision = true;
|
||||
ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true);
|
||||
ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true);
|
||||
ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true);
|
||||
ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, vparams.image_std, 3, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, vparams.hidden_size, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, vparams.n_layer, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
|
||||
ml.get_key(LLM_KV_VISION_CLIP_SELECT_LAYER, vparams.select_layer, true);
|
||||
{
|
||||
std::string name;
|
||||
ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, name, true);
|
||||
vparams.proj_type = clip_projector_type_from_name(name);
|
||||
if (vparams.proj_type == CLIP_PROJECTOR_TYPE_UNKNOWN) {
|
||||
throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str()));
|
||||
}
|
||||
}
|
||||
{
|
||||
std::string name;
|
||||
ml.get_key(LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, name, false);
|
||||
vparams.mm_patch_merge_type = mm_patch_merge_from_name(name);
|
||||
}
|
||||
{
|
||||
std::string arch;
|
||||
ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
|
||||
vparams.arch = vision_arch_from_string(arch);
|
||||
}
|
||||
} else if (!vision_type.empty()) {
|
||||
throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str()));
|
||||
}
|
||||
|
||||
// arch-specific CLIP hparams
|
||||
switch (vparams.arch) {
|
||||
case VISION_ARCH_LLAVA:
|
||||
{
|
||||
ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true);
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model::load_vocab(llama_model_loader & ml) {
|
||||
|
@ -3359,6 +3407,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
}
|
||||
|
||||
// load tensors for vision model
|
||||
auto & vparams = clip.hparams;
|
||||
if (has_vision) {
|
||||
const int64_t n_layer = vparams.n_layer;
|
||||
const int64_t n_embd = vparams.hidden_size;
|
||||
const int64_t n_ff = vparams.n_intermediate;
|
||||
const int64_t max_pos_embd = vparams.max_pos_embd;
|
||||
const int64_t n_channel = 3; // always RGB
|
||||
const int64_t patch_size = vparams.patch_size;
|
||||
const auto tn = VISION_TN(vparams.arch);
|
||||
|
||||
// clip is CPU-only for now
|
||||
clip.buft = ggml_backend_cpu_buffer_type();
|
||||
ggml_context * ctx_vision = ctx_map.at(clip.buft);
|
||||
clip.layers.resize(n_layer);
|
||||
|
||||
switch (vparams.arch) {
|
||||
case VISION_ARCH_LLAVA:
|
||||
{
|
||||
clip.mm_1_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 1), {n_embd, n_ff});
|
||||
clip.mm_1_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 1), {n_ff});
|
||||
clip.mm_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 2), {n_ff, n_ff});
|
||||
clip.mm_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 2), {n_ff});
|
||||
|
||||
clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_embd});
|
||||
clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd});
|
||||
clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd});
|
||||
|
||||
clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd});
|
||||
clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd});
|
||||
clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = clip.layers[i];
|
||||
|
||||
layer.k_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd});
|
||||
layer.k_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_embd});
|
||||
layer.v_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd});
|
||||
layer.v_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_embd});
|
||||
layer.q_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.q_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_embd});
|
||||
|
||||
layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_ff});
|
||||
layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_embd});
|
||||
|
||||
layer.norm_in_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_embd});
|
||||
layer.norm_in_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_embd});
|
||||
layer.norm_out_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd});
|
||||
layer.norm_out_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_embd});
|
||||
|
||||
layer.output_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd});
|
||||
layer.output_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_embd});
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown vision architecture");
|
||||
}
|
||||
|
||||
if (clip_n_mmproj_embd(clip) != hparams.n_embd) {
|
||||
std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
|
||||
}
|
||||
}
|
||||
|
||||
ml.done_getting_tensors();
|
||||
|
||||
ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr);
|
||||
|
|
|
@@ -4,6 +4,7 @@
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-vocab.h"
#include "llama-vision.h"

#include <memory>
#include <string>

@@ -362,6 +363,10 @@ struct llama_model {

    const struct ggml_tensor * get_tensor(const char * name) const;

    // vision
    bool has_vision = false;
    clip_vision_model clip;

private:
    struct impl;
    std::unique_ptr<impl> pimpl;
@ -1,6 +1,7 @@
|
|||
#include "llama.h"
|
||||
#include "llama-vision.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
#include <string.h> // memcpy
|
||||
#include <limits>
|
||||
|
@ -43,15 +44,22 @@ struct clip_image_u8_batch {
|
|||
size_t size;
|
||||
};
|
||||
|
||||
static int clip_n_patches(const clip_context & ctx) {
|
||||
static int clip_n_patches_x(const clip_context & ctx) {
|
||||
auto & hparams = ctx.model->hparams;
|
||||
int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
|
||||
return n_patches;
|
||||
return hparams.image_size / hparams.patch_size;
|
||||
}
|
||||
|
||||
int clip_n_mmproj_embd(const clip_context & ctx) {
|
||||
if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
|
||||
return ctx.model->mm_2_b->ne[0];
|
||||
static int clip_n_patches_y(const clip_context & ctx) {
|
||||
return clip_n_patches_x(ctx);
|
||||
}
|
||||
|
||||
static int clip_n_patches(const clip_context & ctx) {
|
||||
return clip_n_patches_x(ctx) * clip_n_patches_y(ctx);
|
||||
}
|
||||
|
||||
uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) {
|
||||
if (clip_model.hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
|
||||
return clip_model.mm_2_b->ne[0];
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid proj type");
|
||||
}
|
||||
|
@ -242,11 +250,11 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons
|
|||
pad_to_square = false;
|
||||
}
|
||||
|
||||
llama_vision_patches output_imgs;
|
||||
output_imgs.px = clip_n_patches(ctx);
|
||||
output_imgs.py = clip_n_patches(ctx);
|
||||
output_imgs.n_px = params.image_size / output_imgs.px;
|
||||
output_imgs.n_py = params.image_size / output_imgs.py;
|
||||
llama_vision_patches output_patches;
|
||||
output_patches.n_px = clip_n_patches_x(ctx);
|
||||
output_patches.n_py = clip_n_patches_y(ctx);
|
||||
output_patches.px = params.patch_size;
|
||||
output_patches.py = params.patch_size;
|
||||
|
||||
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
|
||||
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
|
||||
|
@ -296,13 +304,13 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons
|
|||
bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
|
||||
patches.insert(patches.begin(), image_original_resize);
|
||||
// clip_image_f32_batch_init(patches.size());
|
||||
output_imgs.buf.resize(patches.size());
|
||||
output_patches.buf.resize(patches.size());
|
||||
int num = 0;
|
||||
for (auto & patch : patches) {
|
||||
normalize_image_u8_to_f32(patch, output_imgs.buf[num], params.image_mean, params.image_std);
|
||||
normalize_image_u8_to_f32(patch, output_patches.buf[num], params.image_mean, params.image_std);
|
||||
num++;
|
||||
}
|
||||
return output_imgs;
|
||||
return output_patches;
|
||||
} else {
|
||||
temp.nx = img.nx;
|
||||
temp.ny = img.ny;
|
||||
|
@ -367,10 +375,10 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons
|
|||
}
|
||||
}
|
||||
|
||||
output_imgs.buf.resize(1);
|
||||
output_imgs.buf[0] = std::move(res);
|
||||
output_patches.buf.resize(1);
|
||||
output_patches.buf[0] = std::move(res);
|
||||
|
||||
return output_imgs;
|
||||
return output_patches;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) {
|
||||
|
@ -556,14 +564,16 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size,
|
|||
}
|
||||
}
|
||||
|
||||
embeddings = ggml_cont(ctx0, embeddings);
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
ggml_free(ctx0);
|
||||
return gf;
|
||||
}
|
||||
|
||||
static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
|
||||
int batch_size = imgs.size();
|
||||
static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches & patches) {
|
||||
int batch_size = patches.buf.size();
|
||||
auto & model = *ctx.model;
|
||||
auto & hparams = ctx.model->hparams;
|
||||
|
||||
|
@ -595,15 +605,15 @@ static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_
|
|||
float * data = (float *)malloc(ggml_nbytes(inp_raw));
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
const int nx = imgs[i].nx;
|
||||
const int ny = imgs[i].ny;
|
||||
const int nx = patches.px * patches.n_px;
|
||||
const int ny = patches.py * patches.n_py;
|
||||
const int n = nx * ny;
|
||||
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
for (int k = 0; k < 3; k++) {
|
||||
for (int y = 0; y < ny; y++) {
|
||||
for (int x = 0; x < nx; x++) {
|
||||
data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k];
|
||||
data[(b * 3 * n) + k * n + y * nx + x] = patches.buf[b][3 * (y * nx + x) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -644,45 +654,71 @@ static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_
|
|||
}
|
||||
|
||||
// compute
|
||||
ggml_backend_sched_graph_compute_async(ctx.sched, gf);
|
||||
ggml_backend_sched_graph_compute(ctx.sched, gf);
|
||||
|
||||
// the last node is the embedding tensor
|
||||
struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings);
|
||||
struct ggml_tensor * output_node = ggml_graph_node(gf, -1);
|
||||
//LLAMA_LOG_INFO("%s: output tensor shape = %lld %lld %lld %lld\n", __func__, output->ne[0], output->ne[1], output->ne[2], output->ne[3]);
|
||||
|
||||
// copy the embeddings to the location passed by the user
|
||||
size_t out_nbytes = clip_n_patches(ctx)*clip_n_mmproj_embd(ctx)*sizeof(float);
|
||||
GGML_ASSERT(out_nbytes == ggml_nbytes(embeddings));
|
||||
output.resize(out_nbytes);
|
||||
ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings));
|
||||
|
||||
ggml_backend_sched_synchronize(ctx.sched);
|
||||
// copy output node to context
|
||||
if (ctx.ctx_ggml) {
|
||||
ggml_free(ctx.ctx_ggml);
|
||||
}
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx.ctx_ggml = ggml_init(params);
|
||||
ctx.output = ggml_dup_tensor(ctx.ctx_ggml, output_node);
|
||||
ggml_backend_alloc_ctx_tensors_from_buft(ctx.ctx_ggml, ctx.model->buft);
|
||||
ggml_backend_tensor_copy(output_node, ctx.output);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
|
||||
clip_image_f32_batch imgs{img};
|
||||
return clip_image_batch_encode(ctx, imgs, output);
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
// public API
|
||||
|
||||
struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) {
|
||||
llama_vision_bitmap * bmp = new llama_vision_bitmap;
|
||||
bmp->nx = nx;
|
||||
bmp->ny = ny;
|
||||
bmp->data = (unsigned char *)malloc(3 * nx * ny);
|
||||
return bmp;
|
||||
}
|
||||
|
||||
static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, std::vector<float> & output_embd) {
|
||||
clip_image_u8 img_u8(img);
|
||||
clip_image_f32_batch img_res_v;
|
||||
auto & hparams = ctx.model->hparams;
|
||||
// bmp_export(img_u8, "test_inp.bmp");
|
||||
|
||||
if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
|
||||
LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
|
||||
return -2;
|
||||
void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
|
||||
free(bmp->data);
|
||||
delete bmp;
|
||||
}
|
||||
|
||||
struct llama_vision_patches * llama_vision_patches_init(
|
||||
struct llama_context * ctx,
|
||||
llama_vision_bitmap * bmp) {
|
||||
clip_context & vctx = ctx->vctx;
|
||||
llama_vision_patches p = clip_image_preprocess(vctx, *bmp);
|
||||
return new llama_vision_patches(p);
|
||||
}
|
||||
|
||||
void llama_vision_patches_free(llama_vision_patches * p) {
|
||||
delete p;
|
||||
}
|
||||
|
||||
int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_patches * p) {
|
||||
if (p->buf.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
clip_context & vctx = ctx->vctx;
|
||||
auto & hparams = vctx.model->hparams;
|
||||
switch (hparams.mm_patch_merge_type) {
|
||||
case MM_PATCH_MERGE_FLAT:
|
||||
{
|
||||
// flat / default llava-1.5 type embedding
|
||||
// n_output = clip_n_patches(ctx);
|
||||
int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd);
|
||||
int32_t encoded = clip_image_encode(vctx, *p);
|
||||
if (encoded != 0) {
|
||||
LLAMA_LOG_ERROR("Unable to encode image\n");
|
||||
return encoded;
|
||||
|
@ -700,44 +736,8 @@ static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, s
|
|||
return 0;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
// public API
|
||||
|
||||
int32_t llama_encode_vision_internal(clip_context & ctx, llama_batch_img * batch) {
|
||||
if (batch->n_imgs == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// TODO: batching is not working atm, should be fixed later
|
||||
const int n_embd = clip_n_mmproj_embd(ctx);
|
||||
const int n_tokens_per_img = clip_n_patches(ctx);
|
||||
const int n_pos = n_tokens_per_img*batch->n_imgs;
|
||||
|
||||
ctx.out_embd.resize(n_embd*n_pos);
|
||||
ctx.out_pos.resize(n_pos);
|
||||
|
||||
for (int i = 0; i < batch->n_imgs; i++) {
|
||||
std::vector<float> output_single;
|
||||
int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single);
|
||||
if (status != 0) {
|
||||
return status;
|
||||
}
|
||||
// copy output embeddings to result
|
||||
for (int k = 0; k < n_embd*n_tokens_per_img; k++) {
|
||||
ctx.out_embd[n_embd*n_tokens_per_img*i + k] = output_single[k];
|
||||
}
|
||||
// fill position for all output tokens
|
||||
for (int p = 0; p < n_tokens_per_img; p++) {
|
||||
ctx.out_pos[n_tokens_per_img*i + p] = batch->pos[i] + p;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void llama_vision_clear_output(clip_context & ctx) {
|
||||
ctx.out_embd.clear();
|
||||
ctx.out_pos.clear();
|
||||
struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) {
|
||||
return ctx->vctx.output;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -2,15 +2,11 @@
|
|||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "llama-arch.h"
|
||||
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
enum vision_arch {
|
||||
VISION_ARCH_UNKNOWN,
|
||||
VISION_ARCH_LLAVA,
|
||||
};
|
||||
|
||||
enum clip_projector_type {
|
||||
CLIP_PROJECTOR_TYPE_UNKNOWN,
|
||||
CLIP_PROJECTOR_TYPE_MLP,
|
||||
|
@ -50,72 +46,76 @@ struct clip_hparams {
|
|||
|
||||
struct clip_layer {
|
||||
// attention
|
||||
struct ggml_tensor * k_w = NULL;
|
||||
struct ggml_tensor * k_b = NULL;
|
||||
struct ggml_tensor * q_w = NULL;
|
||||
struct ggml_tensor * q_b = NULL;
|
||||
struct ggml_tensor * v_w = NULL;
|
||||
struct ggml_tensor * v_b = NULL;
|
||||
struct ggml_tensor * k_w = nullptr;
|
||||
struct ggml_tensor * k_b = nullptr;
|
||||
struct ggml_tensor * q_w = nullptr;
|
||||
struct ggml_tensor * q_b = nullptr;
|
||||
struct ggml_tensor * v_w = nullptr;
|
||||
struct ggml_tensor * v_b = nullptr;
|
||||
|
||||
struct ggml_tensor * output_w = NULL;
|
||||
struct ggml_tensor * output_b = NULL;
|
||||
struct ggml_tensor * output_w = nullptr;
|
||||
struct ggml_tensor * output_b = nullptr;
|
||||
|
||||
// layernorm 1
|
||||
struct ggml_tensor * norm_in_w = NULL;
|
||||
struct ggml_tensor * norm_in_b = NULL;
|
||||
struct ggml_tensor * norm_in_w = nullptr;
|
||||
struct ggml_tensor * norm_in_b = nullptr;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * ffn_up_w = NULL;
|
||||
struct ggml_tensor * ffn_up_b = NULL;
|
||||
struct ggml_tensor * ffn_up_w = nullptr;
|
||||
struct ggml_tensor * ffn_up_b = nullptr;
|
||||
|
||||
struct ggml_tensor * ffn_down_w = NULL;
|
||||
struct ggml_tensor * ffn_down_b = NULL;
|
||||
struct ggml_tensor * ffn_down_w = nullptr;
|
||||
struct ggml_tensor * ffn_down_b = nullptr;
|
||||
|
||||
// layernorm 2
|
||||
struct ggml_tensor * norm_out_w = NULL;
|
||||
struct ggml_tensor * norm_out_b = NULL;
|
||||
struct ggml_tensor * norm_out_w = nullptr;
|
||||
struct ggml_tensor * norm_out_b = nullptr;
|
||||
};
|
||||
|
||||
struct clip_vision_model {
|
||||
struct clip_hparams hparams;
|
||||
ggml_backend_buffer_type_t buft;
|
||||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding = NULL;
|
||||
struct ggml_tensor * patch_embeddings = NULL;
|
||||
struct ggml_tensor * patch_bias = NULL;
|
||||
struct ggml_tensor * position_embeddings = NULL;
|
||||
struct ggml_tensor * class_embedding = nullptr;
|
||||
struct ggml_tensor * patch_embeddings = nullptr;
|
||||
struct ggml_tensor * patch_bias = nullptr;
|
||||
struct ggml_tensor * position_embeddings = nullptr;
|
||||
|
||||
struct ggml_tensor * pre_norm_w = NULL;
|
||||
struct ggml_tensor * pre_norm_b = NULL;
|
||||
struct ggml_tensor * pre_norm_w = nullptr;
|
||||
struct ggml_tensor * pre_norm_b = nullptr;
|
||||
|
||||
std::vector<clip_layer> layers;
|
||||
|
||||
struct ggml_tensor * post_norm_w = NULL;
|
||||
struct ggml_tensor * post_norm_b = NULL;
|
||||
struct ggml_tensor * post_norm_w = nullptr;
|
||||
struct ggml_tensor * post_norm_b = nullptr;
|
||||
|
||||
struct ggml_tensor * projection = NULL;
|
||||
struct ggml_tensor * projection = nullptr;
|
||||
|
||||
// LLaVA projection
|
||||
struct ggml_tensor * mm_1_w = NULL;
|
||||
struct ggml_tensor * mm_1_b = NULL;
|
||||
struct ggml_tensor * mm_2_w = NULL;
|
||||
struct ggml_tensor * mm_2_b = NULL;
|
||||
struct ggml_tensor * mm_1_w = nullptr;
|
||||
struct ggml_tensor * mm_1_b = nullptr;
|
||||
struct ggml_tensor * mm_2_w = nullptr;
|
||||
struct ggml_tensor * mm_2_b = nullptr;
|
||||
|
||||
struct ggml_tensor * image_newline = NULL;
|
||||
struct ggml_tensor * image_newline = nullptr;
|
||||
};
|
||||
|
||||
struct clip_context {
|
||||
// memory buffers used to evaluate the model
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
ggml_backend_sched_t sched = nullptr;
|
||||
struct ggml_context * ctx_ggml = nullptr;
|
||||
|
||||
const clip_vision_model * model;
|
||||
|
||||
// temporary output data, to be picked up by llama_decode()
|
||||
std::vector<float> out_embd; // size == n_tokens * n_embd
|
||||
std::vector<llama_pos> out_pos; // position of each token
|
||||
struct ggml_tensor * output;
|
||||
};
|
||||
|
||||
// for now, this only contains:
|
||||
// - the instruction for ggml_conv_2d to break the image into patches
|
||||
// - the pre-processed image data in f32
|
||||
struct llama_vision_patches {
|
||||
uint32_t px; // size of patch
|
||||
uint32_t py; // size of patch
|
||||
|
@ -126,7 +126,7 @@ struct llama_vision_patches {
|
|||
std::vector<std::vector<float>> buf; // preprocessed image data
|
||||
};
|
||||
|
||||
mm_patch_merge mm_patch_merge_from_name(std::string & name) {
|
||||
inline mm_patch_merge mm_patch_merge_from_name(std::string & name) {
|
||||
if (name == "flat") {
|
||||
return MM_PATCH_MERGE_FLAT;
|
||||
} else if (name == "spatial_unpad") {
|
||||
|
@ -135,17 +135,14 @@ mm_patch_merge mm_patch_merge_from_name(std::string & name) {
|
|||
return MM_PATCH_MERGE_UNKNOWN;
|
||||
}
|
||||
|
||||
clip_projector_type clip_projector_type_from_name(std::string & name) {
|
||||
inline clip_projector_type clip_projector_type_from_name(std::string & name) {
|
||||
if (name == "mlp") {
|
||||
return CLIP_PROJECTOR_TYPE_MLP;
|
||||
}
|
||||
return CLIP_PROJECTOR_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
llama_vision_patches * llama_vision_patches_init(llama_vision_bitmap * bmp);
|
||||
void llama_vision_patches_free(llama_vision_patches * p);
|
||||
// only for sanity check: must be equal to n_embd of language model
|
||||
uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model);
|
||||
|
||||
int32_t llama_vision_encode_impl(clip_context & ctx, llama_vision_patches * p);
|
||||
|
||||
// dimension of the output embeddings, must be equal to n_embd of language model
|
||||
int clip_n_mmproj_embd(const clip_context & ctx);
|
||||
struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx);
|
||||
|
|
|
@ -138,6 +138,9 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|||
), scale);
|
||||
inpL = ggml_add(ctx, inpL, inpL_delta);
|
||||
}
|
||||
} else if (ubatch.embd_tensor) {
|
||||
inpL = ubatch.embd_tensor;
|
||||
ggml_set_input(ubatch.embd_tensor);
|
||||
} else {
|
||||
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
inpL = lctx.inp_embd;
|
||||
|
@ -8466,7 +8469,9 @@ static int llama_decode_impl(
|
|||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor)
|
||||
|| (!batch.token && batch.embd && !batch.embd_tensor)
|
||||
|| (!batch.token && !batch.embd && batch.embd_tensor));
|
||||
|
||||
if (batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
|
@ -9232,7 +9237,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) {
|
|||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
|
||||
llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
|
@ -9785,7 +9790,7 @@ struct llama_context * llama_init_from_model(
|
|||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
|
@ -9794,7 +9799,7 @@ struct llama_context * llama_init_from_model(
|
|||
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
|
||||
ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
|
||||
int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
|
||||
|
@ -9832,6 +9837,13 @@ struct llama_context * llama_init_from_model(
|
|||
}
|
||||
}
|
||||
|
||||
if (model->has_vision) {
|
||||
ctx->vctx.model = &model->clip;
|
||||
ctx->vctx.sched = ctx->sched.get();
|
||||
const size_t max_nodes = 1024;
|
||||
ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
|