add fp16 support for qwen2vl and m-rope
This commit is contained in:
parent 3237bb4614
commit 201f7043c3

3 changed files with 103 additions and 36 deletions
@@ -37,7 +37,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
     vision_model = qwen2vl.visual
     tensor_map = {}
     for name, ten in vision_model.state_dict().items():
-        ten = ten.numpy().astype(dtype)
+        ten = ten.numpy()
         if 'qkv' in name:
             if ten.ndim == 2: # weight
                 c3, _ = ten.shape
@@ -69,17 +69,22 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
         else:
             tensor_map[to_gguf_name(f"vision_model.{name}")] = ten

-    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=dtype)  # dummy tensor, just here as a placeholder
+    for new_name, ten in tensor_map.items():
+        if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+            tensor_map[new_name] = ten.astype(np.float32)
+        else:
+            tensor_map[new_name] = ten.astype(dtype)
+    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
     return tensor_map


-def main(args, data_type='fp32'):
-    if data_type == 'fp32':
+def main(args):
+    if args.data_type == 'fp32':
         dtype = torch.float32
         np_dtype = np.float32
         ftype = 0
-    elif data_type == 'fp16':
-        dtype = torch.float16
+    elif args.data_type == 'fp16':
+        dtype = torch.float32
+        np_dtype = np.float16
         ftype = 1
     else:
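Note on the hunk above: the per-tensor `.astype(dtype)` moves out of the state_dict loop into a single pass at the end, so 1-D tensors and `*_norm.weight` stay in float32 while everything else takes the requested numpy dtype. A minimal, self-contained sketch of that cast rule (illustrative tensor names, not the converter's exact code):

    import numpy as np

    def cast_vision_tensors(tensor_map, dtype=np.float16):
        out = {}
        for name, ten in tensor_map.items():
            if ten.ndim <= 1 or name.endswith("_norm.weight"):
                out[name] = ten.astype(np.float32)   # biases/norms stay fp32
            else:
                out[name] = ten.astype(dtype)        # bulk weights go to the target dtype
        return out

    demo = {"v.blk.0.attn_q.weight": np.ones((4, 4)), "v.blk.0.ln_1.weight": np.ones(4)}
    print({k: v.dtype for k, v in cast_vision_tensors(demo).items()})
    # {'v.blk.0.attn_q.weight': dtype('float16'), 'v.blk.0.ln_1.weight': dtype('float32')}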
@@ -144,5 +149,6 @@ def main(args, data_type='fp32'):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
     args = parser.parse_args()
     main(args)
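With the new `--data_type` flag, fp16 export keeps the torch model in float32 (only the written numpy arrays are halved) and bumps the GGUF file type. A rough summary of the mapping, assuming the usual 0 = F32 / 1 = F16 ftype convention:

    # illustrative only; the values mirror the branches in main() above
    DATA_TYPE_MAP = {
        "fp32": {"torch_dtype": "float32", "np_dtype": "float32", "ftype": 0},
        "fp16": {"torch_dtype": "float32", "np_dtype": "float16", "ftype": 1},
    }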
@@ -181,15 +181,6 @@ static __global__ void rope_vision(
     const int row = blockDim.x*blockIdx.x + threadIdx.x;

-    // if (i0 >= n_dims) {
-    //     const int i = row*ne0 + i0;
-
-    //     dst[i + 0] = x[i + 0];
-    //     dst[i + 1] = x[i + 1];
-
-    //     return;
-    // }
-
     const int i = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows; // i2-th tokens
@@ -348,6 +339,14 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

+static void rope_mrope_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_mrope_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 static void rope_mrope_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
@@ -356,6 +355,14 @@ static void rope_mrope_cuda_f32(
     rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }

+static void rope_vision_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 static void rope_vision_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
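The two new f16 launchers are thin typed wrappers over the same templated kernels used for f32: the kernels read the half inputs, rotate each pair in float precision, and store the result back as half. A rough numpy sketch of that per-pair pattern (an illustration of the precision flow, not the CUDA code itself):

    import numpy as np

    def rotate_pair_f16(x0, x1, theta):
        x0f, x1f = np.float32(x0), np.float32(x1)                     # promote to fp32
        c, s = np.cos(np.float32(theta)), np.sin(np.float32(theta))
        return np.float16(x0f*c - x1f*s), np.float16(x0f*s + x1f*c)   # store back as fp16

    print(rotate_pair_f16(np.float16(0.5), np.float16(-0.25), 0.1))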
@@ -448,11 +455,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
             attn_factor, corr_dims, freq_factors, sections, stream
         );
-    } else if (src0->type == GGML_TYPE_F16 && false) {
-        // rope_mrope_cuda_f16(
-        //     (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-        //     attn_factor, corr_dims, freq_factors, stream
-        // );
+    } else if (src0->type == GGML_TYPE_F16) {
+        rope_mrope_cuda_f16(
+            (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+            attn_factor, corr_dims, freq_factors, sections, stream
+        );
     } else {
         GGML_ABORT("fatal error");
     }
@@ -462,11 +469,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
             attn_factor, corr_dims, freq_factors, sections, stream
         );
-    } else if (src0->type == GGML_TYPE_F16 && false) {
-        // rope_vision_cuda_f16(
-        //     (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-        //     attn_factor, corr_dims, freq_factors, stream
-        // );
+    } else if (src0->type == GGML_TYPE_F16) {
+        rope_vision_cuda_f16(
+            (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+            attn_factor, corr_dims, freq_factors, sections, stream
+        );
     } else {
         GGML_ABORT("fatal error");
     }
@@ -11326,7 +11326,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int) * 4);
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);

     GGML_TENSOR_UNARY_OP_LOCALS
@@ -11466,8 +11466,7 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[0]      = x0*cos_theta - x1*sin_theta;
                         dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
                     }
-                }
-                else {
+                } else {
                     // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -11493,6 +11492,7 @@ static void ggml_compute_forward_rope_f16(
     const struct ggml_tensor * src2 = dst->src[2];

     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];

     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -11505,6 +11505,8 @@ static void ggml_compute_forward_rope_f16(
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+

     GGML_TENSOR_UNARY_OP_LOCALS
@@ -11537,6 +11539,12 @@ static void ggml_compute_forward_rope_f16(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
+    const bool is_vision = is_mrope && sections[3] > 0;
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }

     const float * freq_factors = NULL;
     if (src2 != NULL) {
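The new flags mirror the existing f32 path: sections[] holds the per-stream channel counts read from the op params, any non-zero t/h/w entry selects m-rope, and a non-zero fourth entry selects the vision variant, which additionally requires n_dims == ne0/2. A small Python sketch of that classification (the section values are an assumed example, not taken from this commit):

    def classify_rope(sections, n_dims, ne0):
        is_mrope = sections[0] > 0 or sections[1] > 0 or sections[2] > 0
        is_vision = is_mrope and sections[3] > 0
        if is_vision:
            assert n_dims == ne0 // 2
        return is_mrope, is_vision

    print(classify_rope([16, 24, 24, 0], n_dims=64, ne0=128))  # (True, False): m-rope, not vision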
@@ -11557,7 +11565,19 @@ static void ggml_compute_forward_rope_f16(
-        const int64_t p = pos[i2];
-
         float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+        if (!is_mrope) {
+            const int64_t p = pos[i2];
+            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+        }
+        else {
+            const int64_t p_t = pos[i2];
+            const int64_t p_h = pos[i2 + ne2];
+            const int64_t p_w = pos[i2 + ne2 * 2];
+            const int64_t p_e = pos[i2 + ne2 * 3];
+            ggml_mrope_cache_init(
+                p_t, p_h, p_w, p_e, sections, sections[3] != 0,
+                freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+        }

         for (int64_t i1 = 0; i1 < ne1; i1++) {
             if (ir++ < ir0) continue;
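The m-rope branch pulls four position streams out of pos — temporal at i2, then height, width, and an extra stream at strides of ne2 — and lets each section of the rotary dimensions take its angle from a different stream. An illustrative numpy sketch of that idea (not ggml_mrope_cache_init itself; the section values are an assumed Qwen2-VL-style example):

    import numpy as np

    def mrope_angles(p_t, p_h, p_w, p_e, sections, n_dims, theta_base=10000.0):
        half = n_dims // 2
        inv_freq = theta_base ** (-np.arange(half, dtype=np.float32) / half)
        bounds = np.cumsum(sections)          # e.g. [16, 40, 64, 64]
        pos = np.empty(half, dtype=np.float32)
        for j in range(half):                 # pick a position stream per rotary pair
            if   j < bounds[0]: pos[j] = p_t
            elif j < bounds[1]: pos[j] = p_h
            elif j < bounds[2]: pos[j] = p_w
            else:               pos[j] = p_e
        return pos * inv_freq                 # angle for each of the n_dims/2 pairs

    print(mrope_angles(3, 1, 2, 0, sections=[16, 24, 24, 0], n_dims=128)[:4])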
@@ -11577,6 +11597,22 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
+                } else if (is_vision){
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const int64_t ic = i0/2;
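In the vision variant, rotary pair ic couples channel ic with channel ic + n_dims, and because n_dims == ne0/2 this loop over the first half plus the remaining-channels loop in the next hunk cover the whole head dimension with rotations, rather than copying anything through. A hedged numpy sketch of the pairing, merging the two loops into one:

    import numpy as np

    def rope_vision_row(x, cache):
        ne0 = x.shape[-1]
        n_dims = ne0 // 2
        y = np.empty(ne0, dtype=np.float32)
        for i0 in range(0, ne0, 2):
            ic = i0 // 2
            cos_t, sin_t = cache[i0], cache[i0 + 1]
            x0 = np.float32(x[ic])
            x1 = np.float32(x[ic + n_dims])
            y[ic]          = x0*cos_t - x1*sin_t     # first half of the head
            y[ic + n_dims] = x0*sin_t + x1*cos_t     # partner in the second half
        return y.astype(x.dtype)

    row   = np.arange(8, dtype=np.float16)
    cache = np.tile(np.array([1.0, 0.0], dtype=np.float32), 4)  # cos=1, sin=0 -> identity
    print(rope_vision_row(row, cache))                          # round-trips the input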
@@ -11595,6 +11631,23 @@ static void ggml_compute_forward_rope_f16(
                     }
                 }

+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                } else {
                 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -11605,6 +11658,7 @@ static void ggml_compute_forward_rope_f16(
+                }
             }
         }
     }
 }

 static void ggml_compute_forward_rope(