add metal backend

This commit is contained in:
FSSRepo 2023-12-29 10:32:40 -05:00
parent a52154d3b3
commit 2cf4f37e36
2 changed files with 24 additions and 2 deletions

View file

@ -38,3 +38,6 @@ target_compile_features(llava PRIVATE cxx_std_11)
if(LLAMA_CUBLAS) if(LLAMA_CUBLAS)
add_definitions(-DCLIP_USE_CUBLAS) add_definitions(-DCLIP_USE_CUBLAS)
endif() endif()
# When the build enables LLAMA_METAL, define CLIP_USE_METAL so that sources
# guarded by `#ifdef CLIP_USE_METAL` compile their Metal-specific code
# (the ggml-metal.h include and the Metal backend paths).
if(LLAMA_METAL)
add_definitions(-DCLIP_USE_METAL)
endif()

View file

@ -22,6 +22,10 @@
#include "ggml-cuda.h" #include "ggml-cuda.h"
#endif #endif
#ifdef CLIP_USE_METAL
#include "ggml-metal.h"
#endif
#define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h" #include "stb_image.h"
@ -512,6 +516,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
printf("CLIP using CUDA backend\n"); printf("CLIP using CUDA backend\n");
#endif #endif
#ifdef CLIP_USE_METAL
new_clip->backend = ggml_backend_metal_init();
printf("CLIP using Metal backend\n");
#endif
if(!new_clip->backend) { if(!new_clip->backend) {
new_clip->backend = ggml_backend_cpu_init(); new_clip->backend = ggml_backend_cpu_init();
printf("CLIP using CPU backend\n"); printf("CLIP using CPU backend\n");
@ -594,7 +603,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return nullptr; return nullptr;
} }
int num_bytes = ggml_nbytes(cur); int num_bytes = ggml_nbytes(cur);
if (ggml_backend_is_cpu(new_clip->backend)) { if (ggml_backend_is_cpu(new_clip->backend)
#ifdef CLIP_USE_METAL
|| ggml_backend_is_metal(new_clip->backend)
#endif
) {
// for the CPU and Metal backend, we can read directly into the tensor // for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(cur->data), num_bytes); fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
} else { } else {
@ -882,7 +895,13 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
} }
ggml_backend_graph_compute(ctx->backend, gf); #ifdef CLIP_USE_METAL
if (ggml_backend_is_metal(ctx->backend)) {
ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
}
#endif
ggml_backend_graph_compute(ctx->backend, gf);
// the last node is the embedding tensor // the last node is the embedding tensor
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];