llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)
* phi2 implementation * fix breaking change * phi-2 : various fixes * phi-2 : use layer norm eps * py : whitespaces * llama : fix meta KV override bug * convert : phi don't add BOS token * convert : revert "added_tokens_decoder" change * phi-2 : scale Q instead of KQ for better precision * ggml : fix NeoX rope to rotate just first n_dims * cuda : less diff in the rope_neox kernel * ggml : add ggml_mul_mat_set_prec ggml-ci * Update ggml-cuda.cu Co-authored-by: slaren <slarengh@gmail.com> * Update ggml-cuda.cu Co-authored-by: slaren <slarengh@gmail.com> * cuda : ggml_cuda_op_mul_mat_cublas support F32 precision * cuda : remove oboslete comment --------- Co-authored-by: Ebey Abraham <ebeyabraham@microsoft.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
3c04bf6da8
commit
b9e74f9bca
9 changed files with 463 additions and 76 deletions
|
@ -1555,6 +1555,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_alibi());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue