fix build warning

This commit is contained in:
MengqingCao 2024-08-02 07:43:48 +00:00
parent c9f65bfa76
commit c1111cc096

View file

@ -1313,10 +1313,11 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
}
#endif
void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
ggml_tensor* dst, ggml_tensor* src1,
aclTensor* tmp_cast_tensor,
aclTensor* tmp_im2col_tensor) {
static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
ggml_tensor* dst,
ggml_tensor* src1,
aclTensor* tmp_cast_tensor,
aclTensor* tmp_im2col_tensor) {
// Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
@ -1334,7 +1335,7 @@ void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
ACL_CHECK(aclDestroyTensor(acl_dst));
}
void ggml_cann_im2col_1d_post_process(
static void ggml_cann_im2col_1d_post_process(
ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
const std::vector<int64_t>& im2col_op_params) {
@ -1389,24 +1390,26 @@ void ggml_cann_im2col_1d_post_process(
size_t size_cpy = KH * KW * ggml_type_size(dst->type);
for (int c = 0; c < IC; c++) {
cur_permute_buffer = tmp_permute_buffer + offset +
cur_permute_buffer = (char*)tmp_permute_buffer + offset +
KH * KW * c * ggml_type_size(dst->type);
cur_dst_buffer =
dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type);
cur_dst_buffer = (char*)dst->data +
c * KH * KW * n_step_w * ggml_type_size(dst->type);
for (int i = 0; i < n_step_w; i++) {
ACL_CHECK(aclrtMemcpyAsync(
cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
cur_dst_buffer += KH * KW * ggml_type_size(dst->type);
cur_permute_buffer += KH * KW * IC * ggml_type_size(dst->type);
cur_dst_buffer =
(char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
cur_permute_buffer = (char*)cur_permute_buffer +
KH * KW * IC * ggml_type_size(dst->type);
}
}
} else {
offset = KH * KW * n_step_w *
ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
tmp_permute_buffer + offset, offset,
(char*)tmp_permute_buffer + offset, offset,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
}