From c95937642acf33da1ec6083a7a8ef8d0f3b8ca6c Mon Sep 17 00:00:00 2001 From: Galunid Date: Sun, 5 Nov 2023 07:40:12 +0100 Subject: [PATCH] Update after #3382 --- llama.cpp | 533 +++++++++++++----------------------------------------- 1 file changed, 123 insertions(+), 410 deletions(-) diff --git a/llama.cpp b/llama.cpp index 007683df3..6ebbefc12 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3080,7 +3080,7 @@ static void llm_load_tensors( } break; case LLM_ARCH_STABLELM: { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { @@ -3148,15 +3148,15 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); } } } break; @@ -4639,6 +4639,119 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_stablelm() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; // @@ -4810,408 +4923,7 @@ static const std::unordered_map k_offload_map static llm_offload_trie k_offload_func_trie(k_offload_map); -static struct ggml_cgraph * llm_build_stablelm( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_rot = hparams.n_rot; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - //printf("n_kv = %d\n", n_kv); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_rot, 2, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - - for (int il = 0; il < n_layer; ++il) { - ggml_format_name(inpL, "layer_inp_%d", il); - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - struct ggml_tensor * inpSA = inpL; - - // norm - { - cur = ggml_norm(ctx0, inpL, norm_eps); - offload_func(cur); - ggml_set_name(cur, "norm_0"); - - // cur = cur*attn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); - offload_func(cur); - ggml_set_name(cur, "attention_norm_0"); - - // add bias - cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); - offload_func(cur); - ggml_set_name(cur, "attention_norm_0"); - } - - // self-attention - { - // compute Q and K and RoPE them - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - offload_func_kq(tmpk); - ggml_set_name(tmpk, "tmpk"); - - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - offload_func_kq(tmpq); - ggml_set_name(tmpq, "tmpq"); - - struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_rot, 2, 0, freq_base, freq_scale); - offload_func_kq(Kcur); - ggml_set_name(Kcur, "Kcur"); - - struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_rot, 2, 0, freq_base, freq_scale); - offload_func_kq(Qcur); - ggml_set_name(Qcur, "Qcur"); - - // store key and value to memory - { - // compute the transposed [n_tokens, n_embd] V matrix - - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - offload_func_v(tmpv); - ggml_set_name(tmpv, "tmpv"); - - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - ggml_set_name(v, "v"); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_kv, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - -#if 1 - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); -#else - // make V contiguous in memory to speed up the matmul, however we waste time on the copy - // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation - // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); -#endif - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - offload_func(inpFF); - ggml_set_name(inpFF, "inpFF"); - - // feed-forward network - { - // norm - { - cur = ggml_norm(ctx0, inpFF, norm_eps); - offload_func(cur); - ggml_set_name(cur, "norm_1"); - - // cur = cur*ffn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); - offload_func(cur); - ggml_set_name(cur, "ffn_norm"); - - // add bias - // cur = cur*ffn_norm(broadcasted) - cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b); - offload_func(cur); - ggml_set_name(cur, "ffn_norm"); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - offload_func(tmp); - ggml_set_name(tmp, "result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w1"); - - // SILU activation - cur = ggml_silu(ctx0, cur); - offload_func(cur); - ggml_set_name(cur, "silu"); - - cur = ggml_mul(ctx0, cur, tmp); - offload_func(cur); - ggml_set_name(cur, "silu_x_result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w2"); - } - - cur = ggml_add(ctx0, cur, inpFF); - offload_func(cur); - ggml_set_name(cur, "inpFF_+_result_w2"); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_norm(ctx0, cur, norm_eps); - offload_func_nr(cur); - ggml_set_name(cur, "norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.output_norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - // ggml_set_name(cur, "result_norm"); - - cur = ggml_add(ctx0, cur, model.output_norm_b); - offload_func_nr(cur); - ggml_set_name(cur, "result_norm"); - - } - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} static struct ggml_cgraph * llama_build_graph( llama_context & lctx, @@ -5513,7 +5225,7 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_STABLELM: { - result = llm_build_stablelm(lctx, batch); + result = llm.build_stablelm(); } break; default: GGML_ASSERT(false); @@ -5689,7 +5401,8 @@ static int llama_decode_internal( model.arch == LLM_ARCH_BAICHUAN || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_REFACT || - model.arch == LLM_ARCH_MPT; + model.arch == LLM_ARCH_MPT || + model.arch == LLM_ARCH_STABLELM; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {