MPT : support GQA for replit-code-v1.5 (#3627)
This commit is contained in:
		
							parent
							
								
									11dc1091f6
								
							
						
					
					
						commit
						11bff29045
					
				
					 2 changed files with 5 additions and 3 deletions
				
			
		|  | @ -2839,8 +2839,8 @@ static void llm_load_tensors( | |||
|                         auto & layer = model.layers[i]; | ||||
| 
 | ||||
|                         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); | ||||
|-                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); | ||||
|-                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split); | ||||
|+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); | ||||
|+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},                backend_split); | ||||
| 
 | ||||
|                         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); | ||||
| 
 | ||||
|  | @ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt( | |||
|     const int64_t n_layer     = hparams.n_layer; | ||||
|     const int64_t n_ctx       = cparams.n_ctx; | ||||
|     const int64_t n_head      = hparams.n_head; | ||||
|-    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
 | ||||
|+    const int64_t n_head_kv   = hparams.n_head_kv; | ||||
|     const int64_t n_embd_head = hparams.n_embd_head(); | ||||
|     const int64_t n_embd_gqa  = hparams.n_embd_gqa(); | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue