add ggml_backend_buffer_clear

zero-init KV cache buffer
slaren 2023-12-19 18:31:00 +01:00
parent 0c5ee7c417
commit 1ac01fbbd1
6 changed files with 44 additions and 16 deletions
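
The new call gives callers a backend-agnostic way to fill a buffer with a byte value, replacing the CPU-only memset workaround removed from llama.cpp below. A minimal usage sketch against the public API (illustrative, not part of the commit; it assumes the CPU backend, but any backend that fills in the new clear slot works the same way):

#include "ggml-backend.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();                        // or CUDA/Metal
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, 1 << 20); // 1 MiB
    ggml_backend_buffer_clear(buf, 0);                                       // the API added in this commit
    // ... place tensors in buf; any padding now reads as zero ...
    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    return 0;
}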

ggml-backend-impl.h

@@ -40,6 +40,7 @@ extern "C" {
// (optional) copy tensor between different buffer-type, allow for single-copy transfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
};
struct ggml_backend_buffer {

ggml-backend.c

@@ -94,6 +94,10 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g
return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
}
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
buffer->iface.clear(buffer, value);
}
ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
return buffer->buft;
}
@@ -410,6 +414,10 @@ static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer,
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
memset(buffer->context, value, buffer->size);
}
static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
@@ -418,6 +426,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_cpu_buffer_clear,
};
// for buffers from ptr, free is not called
@@ -429,6 +438,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_cpu_buffer_clear,
};
static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
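
A second sketch, assuming the CPU backend shown above: its clear is a plain memset over the whole allocation, so every requested byte reads back as the fill value through get_base (illustrative, not part of the commit):

#include <assert.h>
#include <stdint.h>
#include "ggml-backend.h"

int main(void) {
    const size_t n = 256;
    ggml_backend_t cpu = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(cpu, n);
    ggml_backend_buffer_clear(buf, 0xAB);
    const uint8_t * base = (const uint8_t *) ggml_backend_buffer_get_base(buf);
    for (size_t i = 0; i < n; ++i) {
        assert(base[i] == 0xAB); // each requested byte holds the fill value
    }
    ggml_backend_buffer_free(buf);
    ggml_backend_free(cpu);
    return 0;
}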

ggml-backend.h

@@ -29,6 +29,7 @@ extern "C" {
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
//

ggml-cuda.cu

@@ -9494,6 +9494,15 @@ static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, co
CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
}
static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
}
static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
/* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
/* .get_base = */ ggml_backend_cuda_buffer_get_base,
@@ -9502,6 +9511,7 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
/* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
/* .cpy_tensor_from = */ NULL,
/* .cpy_tensor_to = */ NULL,
/* .clear = */ ggml_backend_cuda_buffer_clear,
};
// cuda buffer type
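
Note the ordering in the CUDA version: a device-wide synchronize runs before the blocking cudaMemset, so any kernels still writing to the buffer finish first. A stream-ordered variant is also conceivable; the sketch below is illustrative only (it reuses the context type and macros visible in the hunk above) and is not what the commit does:

static void ggml_backend_cuda_buffer_clear_async(ggml_backend_buffer_t buffer, uint8_t value, cudaStream_t stream) {
    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
    ggml_cuda_set_device(ctx->device);
    // fill in stream order; runs after prior work queued on `stream` has drained
    CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, stream));
}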

ggml-metal.m

@@ -2429,8 +2429,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
}
free(ctx);
UNUSED(buffer);
}
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -2457,6 +2455,12 @@ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer
UNUSED(buffer);
}
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
memset(ctx->all_data, value, ctx->all_size);
}
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_get_base,
@@ -2465,6 +2469,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_metal_buffer_clear,
};
// default buffer type

llama.cpp

@@ -1551,9 +1551,8 @@ static bool llama_kv_cache_init(
// buf may be NULL with full offload
if (cache.buf) {
// TODO: ggml_backend_buffer_memset
// this is only valid with CPU buffers!
//memset(ggml_backend_buffer_get_base(cache.buf), 0, ggml_backend_buffer_get_size(cache.buf));
// initialize the buffer to avoid NaNs in the padding
ggml_backend_buffer_clear(cache.buf, 0);
}
if (vram_kv_cache > 0) {
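
The comment above is the heart of the change: padding in the KV cache buffer may never be written, and uninitialized device memory can hold NaNs, which survive any arithmetic and can poison attention results, so the buffer is cleared once at init. A tiny standalone demonstration of the propagation rule (illustrative, not part of the commit):

#include <math.h>
#include <stdio.h>

int main(void) {
    float pad = nanf("");          // stand-in for an uninitialized padding value
    float acc = 1.0f + 0.0f * pad; // even a zero-weighted NaN poisons the sum
    printf("acc = %f\n", acc);     // prints: acc = nan
    return 0;
}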
@@ -3569,8 +3568,12 @@ static void llm_load_tensors(
{
size_t sys_mem_required = ctx_size + buf_size;
{
LLAMA_LOG_INFO("%s: system memory used = %7.2f MiB\n", __func__, sys_mem_required / 1024.0 / 1024.0);
}
if (vram_weights > 0) {
LLAMA_LOG_INFO("%s: VRAM used = %7.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0);
}
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -3586,7 +3589,6 @@ static void llm_load_tensors(
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
#else
GGML_UNUSED(n_gpu_layers);
GGML_UNUSED(vram_weights);
GGML_UNUSED(tensor_split);
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
}
@@ -3601,7 +3603,6 @@ static void llm_load_tensors(
ggml_cuda_set_tensor_split(tensor_split);
#endif // GGML_USE_CUBLAS
// TODO: only pass buf if it is a mmap buffer
ml.load_all_data(ctx, progress_callback, progress_callback_user_data, buf_mmap, use_mlock ? &model.mlock_mmap : NULL);
if (progress_callback) {