clip: enable CUDA backend
parent 2568a4bf54
commit 08e7afacf7

2 changed files with 98 additions and 76 deletions
@@ -34,3 +34,7 @@ add_executable(llava-cli llava-cli.cpp)
 install(TARGETS llava-cli RUNTIME)
 target_link_libraries(llava-cli PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(llava PRIVATE cxx_std_11)
+
+if(LLAMA_CUBLAS)
+    add_definitions(-DCLIP_USE_CUBLAS)
+endif()

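The CUDA path is opt-in and rides on the existing LLAMA_CUBLAS option: when it is set, clip.cpp is compiled with CLIP_USE_CUBLAS and tries the CUDA backend at runtime. A typical configure/build invocation would look like the sketch below (the build directory name is an assumption, not part of this commit):

    cmake -B build -DLLAMA_CUBLAS=ON
    cmake --build build --target llava-cli
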
@@ -16,6 +16,11 @@
 #include "clip.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef CLIP_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
 
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"

@@ -196,20 +201,6 @@ struct clip_vision_model {
     struct ggml_tensor * mm_2_b;
 };
 
-// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
-struct clip_buffer {
-    uint8_t * data = NULL;
-    size_t size = 0;
-
-    void resize(size_t size) {
-        delete[] data;
-        data = new uint8_t[size];
-        this->size = size;
-    }
-
-    ~clip_buffer() { delete[] data; }
-};
-
 struct clip_ctx {
     bool has_text_encoder = false;
     bool has_vision_encoder = false;

@@ -223,9 +214,10 @@ struct clip_ctx {
     struct gguf_context * ctx_gguf;
 
     // memory buffers to evaluate the model
-    clip_buffer buf_compute;
-    clip_buffer buf_alloc;
-    ggml_allocr * alloc = NULL;
+    ggml_backend_buffer_t params_buffer = NULL;
+    ggml_backend_buffer_t compute_buffer = NULL;
+    ggml_backend_t backend = NULL;
+    ggml_allocr * compute_alloc = NULL;
 };
 
 static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {

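The hand-rolled clip_buffer (plain host memory) is gone: model weights and compute scratch space now live in ggml-backend buffers, so they can sit in device memory when the CUDA backend is active. A minimal sketch of the object lifecycle these new fields imply, using only calls that appear later in this commit (buffer_size, tensor and host_data are placeholders):

    // pick a backend, carve a buffer out of it, and hand the buffer to an allocator
    ggml_backend_t        backend = ggml_backend_cpu_init();   // or ggml_backend_cuda_init()
    ggml_backend_buffer_t buffer  = ggml_backend_alloc_buffer(backend, buffer_size);
    ggml_allocr *         alloc   = ggml_allocr_new_from_buffer(buffer);

    // place a tensor inside the buffer and upload its data from host memory
    ggml_allocr_alloc(alloc, tensor);
    ggml_backend_tensor_set(tensor, host_data, 0, ggml_nbytes(tensor));

    // the allocator can be freed once all tensors are placed
    ggml_allocr_free(alloc);
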
@@ -252,25 +244,20 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if(ctx->has_llava_projector) {
         GGML_ASSERT(batch_size == 1);
     }
 
-    const auto & buf_compute = ctx->buf_compute;
-
     struct ggml_init_params params = {
-        /*.mem_size =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc =*/ false,
+        /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc =*/ true,
     };
 
-    params.no_alloc = true;
-
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
-    ggml_allocr_alloc(ctx->alloc, inp_raw);
+    ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
 
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        float * data = (float *)ggml_get_data(inp_raw);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
 
         for (size_t i = 0; i < imgs->size; i++) {
             const int nx = imgs->data[i].nx;

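With no_alloc set and the tensors placed by ggml_allocr, their data may live in VRAM, so the graph builder can no longer write through ggml_get_data() or the ggml_set_*() helpers. Every input is instead prepared in a temporary host buffer and copied in with ggml_backend_tensor_set(), and the copy is skipped while ctx->compute_alloc is a measuring allocator (the graph is also built once just to size the compute buffer). The same pattern repeats below for the raw image, the class embedding, the positions, the KQ scale and the patch indices; distilled (tensor and the fill step are placeholders):

    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {   // skip uploads during the sizing pass
        void * staging = malloc(ggml_nbytes(tensor));    // host-side staging buffer
        // ... fill staging with the input values ...
        ggml_backend_tensor_set(tensor, staging, 0, ggml_nbytes(tensor));  // copy to CPU or GPU memory
        free(staging);
    }
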
@@ -289,6 +276,8 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
                 }
             }
         }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
     }
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);

@@ -298,13 +287,16 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
 
     // concat class_embeddings and patch_embeddings
     struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-    ggml_allocr_alloc(ctx->alloc, embeddings);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        ggml_set_zero(embeddings);
+    ggml_allocr_alloc(ctx->compute_alloc, embeddings);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        void* zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
     }
 
     struct ggml_tensor * temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
-    ggml_allocr_alloc(ctx->alloc, temp);
+    ggml_allocr_alloc(ctx->compute_alloc, temp);
 
     embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1],
         embeddings->nb[2], embeddings->nb[3], 0);

@@ -312,11 +304,14 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
 
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
-    ggml_allocr_alloc(ctx->alloc, positions);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
+    ggml_allocr_alloc(ctx->compute_alloc, positions);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        int* positions_data = (int*)malloc(ggml_nbytes(positions));
         for (int i = 0; i < num_positions; i++) {
-            ggml_set_i32_1d(positions, i, i);
+            positions_data[i] = i;
         }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
     }
 
     embeddings =

@@ -331,9 +326,11 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     }
 
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    ggml_allocr_alloc(ctx->alloc, KQ_scale);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
+    ggml_allocr_alloc(ctx->compute_alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        float scale = 1.0f / sqrt((float)d_head);
+        ggml_backend_tensor_set(KQ_scale, &scale, 0, ggml_nbytes(KQ_scale));
+        printf("alloc scale\n");
     }
 
     // loop over layers

@@ -423,11 +420,15 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
     struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-    ggml_allocr_alloc(ctx->alloc, patches);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        for (int i = 0; i < num_patches; ++i) {
-            ggml_set_i32_1d(patches, i, i+1);
+    ggml_allocr_alloc(ctx->compute_alloc, patches);
+    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
+        int* patches_data = (int*)malloc(ggml_nbytes(patches));
+        for (int i = 0; i < num_positions; i++) {
+            patches_data[i] = i+1;
         }
+        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+        free(patches_data);
+        printf("patches");
     }
 
     embeddings = ggml_get_rows(ctx0, embeddings, patches);

@@ -485,7 +486,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         printf("%s: ftype: %s\n", __func__, ftype_str.c_str());
         printf("\n");
     }
-
+    const int n_tensors = gguf_get_n_tensors(ctx);
     // kv
     if (verbosity >= 3) {
        const int n_kv = gguf_get_n_kv(ctx);

@@ -499,19 +500,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
 
     // data
-    size_t ctx_size = 0;
+    size_t buffer_size = 0;
     {
-        const int n_tensors = gguf_get_n_tensors(ctx);
-
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
             const size_t offset = gguf_get_tensor_offset(ctx, i);
-
             struct ggml_tensor * cur = ggml_get_tensor(meta, name);
-            ctx_size += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
             size_t tensor_size = ggml_nbytes(cur);
             size_t padded_size = ggml_nbytes_pad(cur);
-            ctx_size += padded_size;
+            buffer_size += padded_size;
             if (verbosity >= 3) {
                 printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, padded_size=%zu, offset=%zu\n", __func__, i,
                        cur->n_dims, cur->name, tensor_size, padded_size, offset);

@@ -520,6 +517,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
 
     clip_ctx * new_clip = new clip_ctx;
+#ifdef CLIP_USE_CUBLAS
+    new_clip->backend = ggml_backend_cuda_init();
+    printf("CLIP using CUDA backend\n");
+#endif
+
+    if(!new_clip->backend) {
+        new_clip->backend = ggml_backend_cpu_init();
+        printf("CLIP using CPU backend\n");
+    }
 
     // model size and capabilities
     {

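The backend is chosen once at load time. The CUDA branch only exists when CLIP_USE_CUBLAS was defined at compile time, and because clip_ctx initializes backend to NULL, the fallback also covers the case where ggml_backend_cuda_init() fails and the pointer stays NULL. Reduced to its core, the selection looks like this:

    ggml_backend_t backend = NULL;
    #ifdef CLIP_USE_CUBLAS
    backend = ggml_backend_cuda_init();   // left NULL if CUDA initialization fails
    #endif
    if (!backend) {
        backend = ggml_backend_cpu_init();   // CPU is always available as a fallback
    }
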
@@ -545,17 +551,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
            printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
            printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
            printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
-            printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
+            printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
            printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
         }
     }
 
+    printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+
     // load tensors
     {
+        std::vector<uint8_t> read_buf;
         struct ggml_init_params params = {
-            /*.mem_size =*/ ctx_size,
+            /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
-            /*.no_alloc =*/ false,
+            /*.no_alloc =*/ true,
         };
 
         new_clip->ctx = ggml_init(params);

@@ -572,13 +581,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             return nullptr;
         }
 
-        const int n_tensors = gguf_get_n_tensors(ctx);
+        // add tensors to context
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
             struct ggml_tensor * t = ggml_get_tensor(meta, name);
             struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx, t);
             ggml_set_name(cur, name);
+        }
+
+        // alloc memory and offload data
+        new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
+        ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx, name);
+            ggml_allocr_alloc(alloc, cur);
             const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
             fin.seekg(offset, std::ios::beg);
             if (!fin) {

@@ -586,10 +603,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                 clip_free(new_clip);
                 return nullptr;
             }
-            fin.read(reinterpret_cast<char *>(cur->data), ggml_nbytes(t));
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_is_cpu(new_clip->backend)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
         }
+        ggml_allocr_free(alloc);
         fin.close();
     }
 

@@ -663,18 +688,17 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
 
     // measure mem requirement and allocate
     {
+        new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
         static const size_t tensor_alignment = 32;
-        new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
-        new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
         clip_image_f32_batch batch;
         batch.size = 1;
         ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
-        size_t alloc_size = ggml_allocr_alloc_graph(new_clip->alloc, gf) + tensor_alignment;
-        ggml_allocr_free(new_clip->alloc);
-        new_clip->buf_alloc.resize(alloc_size);
-        new_clip->alloc = ggml_allocr_new(new_clip->buf_alloc.data, new_clip->buf_alloc.size, tensor_alignment);
+        size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
+        ggml_allocr_free(new_clip->compute_alloc);
+        new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
+        new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
 
-        printf("%s: total allocated memory: %.2f MB\n", __func__, (new_clip->buf_compute.size + alloc_size)/1024.0/1024.0);
+        printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
     }
 
     return new_clip;

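The compute buffer is sized with a measure-then-allocate scheme: the graph is built once under a measuring allocator created directly from the backend (which is why the ggml_allocr_is_measure() guards earlier skip all uploads), the reported size is used to allocate a real backend buffer, and compute_alloc is recreated on top of it. Condensed, with ctx standing in for new_clip and batch for the one-image measuring batch:

    ctx->compute_alloc = ggml_allocr_new_measure_from_backend(ctx->backend);   // pass 1: measure only
    ggml_cgraph * gf = clip_image_build_graph(ctx, &batch);
    size_t needed = ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
    ggml_allocr_free(ctx->compute_alloc);

    ctx->compute_buffer = ggml_backend_alloc_buffer(ctx->backend, needed);     // pass 2: real buffer
    ctx->compute_alloc  = ggml_allocr_new_from_buffer(ctx->compute_buffer);
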
@@ -858,29 +882,23 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
     }
 
     // reset alloc buffer to clean the memory from previous invocations
-    ggml_allocr_reset(ctx->alloc);
+    ggml_allocr_reset(ctx->compute_alloc);
 
     // build the inference graph
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
-    ggml_allocr_alloc_graph(ctx->alloc, gf);
+    ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
 
-    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
-    if (plan.work_size > 0) {
-        plan.work_data = (uint8_t *)malloc(plan.work_size);
+    if (ggml_backend_is_cpu(ctx->backend)) {
+        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
     }
 
-    ggml_graph_compute(gf, &plan);
+    ggml_backend_graph_compute(ctx->backend, gf);
 
     // the last node is the embedding tensor
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
 
     // copy the embeddings to the location passed by the user
-    memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
-
-    if (plan.work_size > 0) {
-        free(plan.work_data);
-    }
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
 
     return true;
 }

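Taken together, the encode path no longer builds a ggml_cplan or manages its own work buffer: the thread count matters only for the CPU backend, the graph runs wherever the backend lives, and the result comes back through an explicit device-to-host copy. The call sequence after this change, condensed (embeddings is the last graph node and vec the caller's output buffer, as in the hunk above):

    ggml_allocr_reset(ctx->compute_alloc);                 // reuse the pre-sized compute buffer
    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
    ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
    if (ggml_backend_is_cpu(ctx->backend)) {
        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
    }
    ggml_backend_graph_compute(ctx->backend, gf);
    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));   // copy result back to host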