This commit is contained in:
slaren 2024-06-07 02:18:27 +02:00
parent 0425305d32
commit 2dd049ed45

View file

@ -13,10 +13,12 @@
#endif #endif
struct ggml_backend_blas_context { struct ggml_backend_blas_context {
int n_threads; int n_threads = GGML_DEFAULT_N_THREADS;
char * work_data; std::unique_ptr<char[]> work_data;
size_t work_size; size_t work_size = 0;
#ifndef GGML_USE_OPENMP
std::vector<std::future<void>> tasks; std::vector<std::future<void>> tasks;
#endif
}; };
// helper function to determine if it is better to use BLAS or not // helper function to determine if it is better to use BLAS or not
@ -76,11 +78,10 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float); const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
if (ctx->work_size < desired_wsize) { if (ctx->work_size < desired_wsize) {
delete[] ctx->work_data; ctx->work_data.reset(new char[desired_wsize]);
ctx->work_data = new char[desired_wsize];
ctx->work_size = desired_wsize; ctx->work_size = desired_wsize;
} }
void * wdata = ctx->work_data; void * wdata = ctx->work_data.get();
// convert src0 to float // convert src0 to float
if (type != GGML_TYPE_F32) { if (type != GGML_TYPE_F32) {
@ -212,7 +213,6 @@ GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) { GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
delete[] ctx->work_data;
delete ctx; delete ctx;
delete backend; delete backend;
} }
@ -306,11 +306,7 @@ static ggml_guid_t ggml_backend_blas_guid(void) {
} }
ggml_backend_t ggml_backend_blas_init(void) { ggml_backend_t ggml_backend_blas_init(void) {
ggml_backend_blas_context * ctx = new ggml_backend_blas_context{ ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
/* .n_threads = */ GGML_DEFAULT_N_THREADS,
/* .work_data = */ NULL,
/* .work_size = */ 0,
};
ggml_backend_t backend = new ggml_backend { ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_blas_guid(), /* .guid = */ ggml_backend_blas_guid(),