phi-2 : various fixes
This commit is contained in:
parent
e20765534d
commit
a2a3d2c8d7
5 changed files with 66 additions and 37 deletions
11
ggml-cuda.cu
11
ggml-cuda.cu
|
@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
|
|||
const int ib = col / n_dims;
|
||||
const int ic = col % n_dims;
|
||||
|
||||
const int i = row*ncols + ib*n_dims + ic/2;
|
||||
// IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2)
|
||||
// I don't know what we are supposed to compute, because the row is not divisible by n_dims
|
||||
// this check matches the CPU code, but it is likely wrong as well
|
||||
// I can't understand the Python code, so if you know what to do here, please fix it
|
||||
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
|
||||
if (ncols % n_dims != 0 && ib == ncols/n_dims) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i = row*ncols + ib*n_dims + ic/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue