diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 0c8bb3ed0..c6d966581 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -37,7 +37,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: vision_model = qwen2vl.visual tensor_map = {} for name, ten in vision_model.state_dict().items(): - ten = ten.numpy().astype(dtype) + ten = ten.numpy() if 'qkv' in name: if ten.ndim == 2: # weight c3, _ = ten.shape @@ -68,18 +68,23 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] else: tensor_map[to_gguf_name(f"vision_model.{name}")] = ten - - tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=dtype) # dummy tensor, just here as a placeholder + + for new_name, ten in tensor_map.items(): + if ten.ndim <= 1 or new_name.endswith("_norm.weight"): + tensor_map[new_name] = ten.astype(np.float32) + else: + tensor_map[new_name] = ten.astype(dtype) + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder return tensor_map -def main(args, data_type='fp32'): - if data_type == 'fp32': +def main(args): + if args.data_type == 'fp32': dtype = torch.float32 np_dtype = np.float32 ftype = 0 - elif data_type == 'fp16': - dtype = torch.float16 + elif args.data_type == 'fp16': + dtype = torch.float32 np_dtype = np.float16 ftype = 1 else: @@ -144,5 +149,6 @@ def main(args, data_type='fp32'): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") args = parser.parse_args() main(args) \ No newline at end of file diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 13a5d1005..9a66badb9 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -181,15 +181,6 @@ static __global__ void rope_vision( const int row = blockDim.x*blockIdx.x + threadIdx.x; - // if (i0 >= n_dims) { - // const int i = row*ne0 + i0; - - // dst[i + 0] = x[i + 0]; - // dst[i + 1] = x[i + 1]; - - // return; - // } - const int i = row*ne0 + i0/2; const int i2 = row/p_delta_rows; // i2-th tokens @@ -348,6 +339,14 @@ static void rope_neox_cuda_f32( rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } +static void rope_mrope_cuda_f16( + const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream +) { + + rope_mrope_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); +} + static void rope_mrope_cuda_f32( const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream @@ -356,6 +355,14 @@ static void rope_mrope_cuda_f32( rope_mrope_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } +static void rope_vision_cuda_f16( + const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream +) { + + rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); +} + static void rope_vision_cuda_f32( const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream @@ -448,11 +455,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream ); - } else if (src0->type == GGML_TYPE_F16 && false) { - // rope_mrope_cuda_f16( - // (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - // attn_factor, corr_dims, freq_factors, stream - // ); + } else if (src0->type == GGML_TYPE_F16) { + rope_mrope_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, sections, stream + ); } else { GGML_ABORT("fatal error"); } @@ -462,11 +469,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream ); - } else if (src0->type == GGML_TYPE_F16 && false) { - // rope_vision_cuda_f16( - // (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - // attn_factor, corr_dims, freq_factors, stream - // ); + } else if (src0->type == GGML_TYPE_F16) { + rope_vision_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, sections, stream + ); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e589d6552..7b16637cd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -11326,7 +11326,7 @@ static void ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); GGML_TENSOR_UNARY_OP_LOCALS @@ -11463,11 +11463,10 @@ static void ggml_compute_forward_rope_f32( const float x0 = src[0]; const float x1 = src[n_dims]; - dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[n_dims] = x0*sin_theta + x1*cos_theta; } - } - else { + } else { // fill the remain channels with data from src tensor for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -11493,6 +11492,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src2 = dst->src[2]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -11505,6 +11505,8 @@ static void ggml_compute_forward_rope_f16( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); + GGML_TENSOR_UNARY_OP_LOCALS @@ -11537,6 +11539,12 @@ static void ggml_compute_forward_rope_f16( ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0; + const bool is_vision = is_mrope && sections[3] > 0; + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } const float * freq_factors = NULL; if (src2 != NULL) { @@ -11557,7 +11565,19 @@ static void ggml_compute_forward_rope_f16( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + if (!is_mrope) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, sections[3] != 0, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; @@ -11577,6 +11597,22 @@ static void ggml_compute_forward_rope_f16( dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } + } else if (is_vision){ + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } } else { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const int64_t ic = i0/2; @@ -11595,12 +11631,30 @@ static void ggml_compute_forward_rope_f16( } } - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + if (is_vision) { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const int64_t ic = i0/2; - dst_data[0] = src[0]; - dst_data[1] = src[1]; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } }