Free CUDA scratch buffer upon llama_model deletion
parent ed6587491c
commit 20e76a0764

3 changed files with 11 additions and 0 deletions
ggml-cuda.cu
@@ -2347,6 +2347,15 @@ void ggml_cuda_set_scratch_size(size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
 
+void ggml_cuda_free_scratch() {
+    if (g_scratch_buffer == nullptr) {
+        return;
+    }
+
+    CUDA_CHECK(cudaFree(g_scratch_buffer));
+    g_scratch_buffer = nullptr;
+}
+
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
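The new free is deliberately safe to call at any time: the null check makes it a no-op before the scratch buffer has ever been allocated, and resetting g_scratch_buffer afterwards both prevents a double cudaFree and lets the buffer be re-allocated lazily on the next use. Below is a minimal self-contained sketch of that lazy-alloc/idempotent-free pattern; CUDA_CHECK is re-implemented just for this example, and scratch_get/scratch_free are illustrative names, not ggml API.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Abort on any CUDA error, in the spirit of ggml's CUDA_CHECK macro.
#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                    cudaGetErrorString(err_), __FILE__, __LINE__);    \
            exit(1);                                                  \
        }                                                             \
    } while (0)

static void * g_scratch_buffer = nullptr;           // mirrors the global in ggml-cuda.cu
static size_t g_scratch_size   = 16 * 1024 * 1024;  // 16 MiB, arbitrary for the sketch

// Lazily allocate the scratch buffer on first use.
static void * scratch_get() {
    if (g_scratch_buffer == nullptr) {
        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
    }
    return g_scratch_buffer;
}

// Idempotent free: no-op when nothing is allocated, and the pointer is
// reset so a later scratch_get() can allocate again.
static void scratch_free() {
    if (g_scratch_buffer == nullptr) {
        return;
    }
    CUDA_CHECK(cudaFree(g_scratch_buffer));
    g_scratch_buffer = nullptr;
}

int main() {
    scratch_free();             // safe before any allocation
    void * buf = scratch_get();
    printf("scratch at %p\n", buf);
    scratch_free();             // releases the device memory
    scratch_free();             // safe to call twice
    return 0;
}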
ggml-cuda.h
@@ -31,6 +31,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 void ggml_cuda_set_main_device(int main_device);
 void ggml_cuda_set_scratch_size(size_t scratch_size);
+void ggml_cuda_free_scratch(void);
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
 
 #ifdef __cplusplus
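The declaration uses an explicit (void) parameter list so that C callers see "takes no arguments" rather than "takes unspecified arguments", and the hunk's trailing #ifdef __cplusplus context suggests the header wraps its declarations in the usual extern "C" guard. A small sketch of that pattern, compilable as either C or C++; demo_free_scratch is a stand-in name, not part of ggml-cuda.h.

#include <stdio.h>

/* extern "C" guard so C++ callers link against the unmangled C symbol;
 * demo_free_scratch is illustrative only. */
#ifdef __cplusplus
extern "C" {
#endif
void demo_free_scratch(void);   /* (void): no arguments, even in C */
#ifdef __cplusplus
}
#endif

void demo_free_scratch(void) {
    printf("scratch freed\n");
}

int main(void) {
    demo_free_scratch();
    return 0;
}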
llama.cpp
@@ -215,6 +215,7 @@ struct llama_model {
     for (size_t i = 0; i < tensors_by_name.size(); ++i) {
         ggml_cuda_free_data(tensors_by_name[i].second);
     }
+    ggml_cuda_free_scratch();
 #elif defined(GGML_USE_CLBLAST)
     for (size_t i = 0; i < tensors_by_name.size(); ++i) {
         ggml_cl_free_data(tensors_by_name[i].second);
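Tying the release to llama_model teardown returns the scratch VRAM as soon as the model object is destroyed, without requiring a separate cleanup call from API users; because the buffer is allocated lazily, a later evaluation can simply re-create it. A hedged, CUDA-free C++ sketch of that design choice, where Model, acquire_scratch, and release_scratch are illustrative names rather than llama.cpp API:

#include <cstdio>

static void * g_scratch = nullptr;  // stands in for the CUDA scratch buffer

static void acquire_scratch() {
    if (g_scratch == nullptr) {
        g_scratch = new char[1024];   // real code: cudaMalloc
        printf("scratch allocated\n");
    }
}

static void release_scratch() {
    if (g_scratch == nullptr) return; // idempotent, like ggml_cuda_free_scratch
    delete[] static_cast<char *>(g_scratch); // real code: cudaFree
    g_scratch = nullptr;
    printf("scratch freed\n");
}

struct Model {
    void eval() { acquire_scratch(); }  // lazy allocation on first use
    ~Model()    { release_scratch(); }  // freed when the model is destroyed
};

int main() {
    {
        Model m;
        m.eval();
    }   // scratch freed here, together with the model
    return 0;
}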