phi-2 : various fixes

This commit is contained in:
Georgi Gerganov 2023-12-16 10:46:18 +02:00
parent e20765534d
commit a2a3d2c8d7
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 66 additions and 37 deletions

View file

@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;
const int i = row*ncols + ib*n_dims + ic/2;
// IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2)
// I don't know what we are supposed to compute, because the row length (ncols) is not divisible by n_dims
// this check matches the CPU code, but it is likely wrong as well
// I can't understand the Python code, so if you know what to do here, please fix it
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
if (ncols % n_dims != 0 && ib == ncols/n_dims) {
return;
}
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;