fix offloading logic

This commit is contained in:
JohannesGaessler 2023-08-27 11:21:26 +02:00
parent e23fa92401
commit 21df40d0c4

View file

@ -2233,7 +2233,6 @@ static struct ggml_cgraph * llm_build_llama(
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop; offload_func_t offload_func_v = llama_nop;
offload_func_t offload_func_skip = llama_nop;
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) { if (n_gpu_layers > n_layer) {
@ -2245,9 +2244,6 @@ static struct ggml_cgraph * llm_build_llama(
if (n_gpu_layers > n_layer + 2) { if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc; offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
} }
if (n_gpu_layers > 0) {
offload_func_skip = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
@ -2331,11 +2327,11 @@ static struct ggml_cgraph * llm_build_llama(
// otherwise for Metal we'd have to rebuild the concurrency list. // otherwise for Metal we'd have to rebuild the concurrency list.
cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd); cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd);
offload_func_skip(cur); offload_func_kq(cur);
ggml_set_name(cur, "cur-lastpos"); ggml_set_name(cur, "cur-lastpos");
inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd); inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd);
offload_func_skip(inpSA); offload_func(inpSA);
ggml_set_name(inpSA, "inpSA-lastpos"); ggml_set_name(inpSA, "inpSA-lastpos");
n_past += N - 1; n_past += N - 1;