llama : modified tensor permutations to multiply larger matrices during inference
This commit is contained in:
		
							parent
							
								
									202f323e66
								
							
						
					
					
						commit
						93c5937249
					
				
					 1 changed files with 10 additions and 10 deletions
				
			
		|  | @ -6529,13 +6529,13 @@ struct llm_build_context { | |||
|                 struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); | ||||
|                 cb(wk_b, "wk_b", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1); | ||||
|                 struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); | ||||
|                 cb(q_nope_perm, "q_nope_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); | ||||
|                 cb(q_nope2, "q_nope2", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); | ||||
|                 struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); | ||||
|                 cb(q_nope2_perm, "q_nope2_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); | ||||
|  | @ -6547,34 +6547,34 @@ struct llm_build_context { | |||
|                 struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); | ||||
|                 cb(kr_cache_perm, "kr_cache_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); | ||||
|                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); | ||||
|                 cb(kq_pe, "kq_pe", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); | ||||
|                 cb(kq, "kq", il); | ||||
| 
 | ||||
|                 kq = ggml_permute(ctx0, kq, 0, 3, 1, 2); | ||||
|                 kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); | ||||
|                 cb(kq, "kq_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); | ||||
|                 cb(wv_b, "wv_b", il); | ||||
| 
 | ||||
|                 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); | ||||
|                 cb(kq, "kq_soft_max_ext", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); | ||||
|                 struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3); | ||||
|                 cb(kq_perm, "kq_soft_max_ext_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); | ||||
|                 cb(kqv_compressed, "kqv_compressed", il); | ||||
| 
 | ||||
|                 kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 1, 3, 2); | ||||
|                 kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); | ||||
|                 cb(kqv_compressed, "kqv_compressed_perm", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); | ||||
|                 cb(wv_b, "wv_b", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); | ||||
|                 cb(kqv, "kqv", il); | ||||
| 
 | ||||
|                 kqv = ggml_permute(ctx0, kqv, 0, 3, 1, 2); | ||||
|                 kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); | ||||
|                 cb(kqv, "kqv_perm", il); | ||||
| 
 | ||||
|                 cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue