fix build warning

This commit is contained in:
MengqingCao 2024-08-02 07:43:48 +00:00
parent c9f65bfa76
commit c1111cc096

View file

@ -1313,10 +1313,11 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
} }
#endif #endif
void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
ggml_tensor* dst, ggml_tensor* src1, ggml_tensor* dst,
aclTensor* tmp_cast_tensor, ggml_tensor* src1,
aclTensor* tmp_im2col_tensor) { aclTensor* tmp_cast_tensor,
aclTensor* tmp_im2col_tensor) {
// Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW] // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]}; int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]}; size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
@ -1334,7 +1335,7 @@ void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyTensor(acl_dst));
} }
void ggml_cann_im2col_1d_post_process( static void ggml_cann_im2col_1d_post_process(
ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1, ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor, aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
const std::vector<int64_t>& im2col_op_params) { const std::vector<int64_t>& im2col_op_params) {
@ -1389,24 +1390,26 @@ void ggml_cann_im2col_1d_post_process(
size_t size_cpy = KH * KW * ggml_type_size(dst->type); size_t size_cpy = KH * KW * ggml_type_size(dst->type);
for (int c = 0; c < IC; c++) { for (int c = 0; c < IC; c++) {
cur_permute_buffer = tmp_permute_buffer + offset + cur_permute_buffer = (char*)tmp_permute_buffer + offset +
KH * KW * c * ggml_type_size(dst->type); KH * KW * c * ggml_type_size(dst->type);
cur_dst_buffer = cur_dst_buffer = (char*)dst->data +
dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type); c * KH * KW * n_step_w * ggml_type_size(dst->type);
for (int i = 0; i < n_step_w; i++) { for (int i = 0; i < n_step_w; i++) {
ACL_CHECK(aclrtMemcpyAsync( ACL_CHECK(aclrtMemcpyAsync(
cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy, cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
cur_dst_buffer += KH * KW * ggml_type_size(dst->type); cur_dst_buffer =
cur_permute_buffer += KH * KW * IC * ggml_type_size(dst->type); (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
cur_permute_buffer = (char*)cur_permute_buffer +
KH * KW * IC * ggml_type_size(dst->type);
} }
} }
} else { } else {
offset = KH * KW * n_step_w * offset = KH * KW * n_step_w *
ggml_type_size(dst->type); // equal to ggml_nbytes(dst) ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
ACL_CHECK(aclrtMemcpyAsync(dst->data, offset, ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
tmp_permute_buffer + offset, offset, (char*)tmp_permute_buffer + offset, offset,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
} }