fix build warning
This commit is contained in:
parent
c9f65bfa76
commit
c1111cc096
1 changed files with 14 additions and 11 deletions
|
@ -1313,10 +1313,11 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
|
|||
}
|
||||
#endif
|
||||
|
||||
void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
||||
ggml_tensor* dst, ggml_tensor* src1,
|
||||
aclTensor* tmp_cast_tensor,
|
||||
aclTensor* tmp_im2col_tensor) {
|
||||
static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
||||
ggml_tensor* dst,
|
||||
ggml_tensor* src1,
|
||||
aclTensor* tmp_cast_tensor,
|
||||
aclTensor* tmp_im2col_tensor) {
|
||||
// Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
|
||||
int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
|
||||
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
|
||||
|
@ -1334,7 +1335,7 @@ void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
|||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||
}
|
||||
|
||||
void ggml_cann_im2col_1d_post_process(
|
||||
static void ggml_cann_im2col_1d_post_process(
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
|
||||
aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
|
||||
const std::vector<int64_t>& im2col_op_params) {
|
||||
|
@ -1389,24 +1390,26 @@ void ggml_cann_im2col_1d_post_process(
|
|||
size_t size_cpy = KH * KW * ggml_type_size(dst->type);
|
||||
|
||||
for (int c = 0; c < IC; c++) {
|
||||
cur_permute_buffer = tmp_permute_buffer + offset +
|
||||
cur_permute_buffer = (char*)tmp_permute_buffer + offset +
|
||||
KH * KW * c * ggml_type_size(dst->type);
|
||||
cur_dst_buffer =
|
||||
dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type);
|
||||
cur_dst_buffer = (char*)dst->data +
|
||||
c * KH * KW * n_step_w * ggml_type_size(dst->type);
|
||||
|
||||
for (int i = 0; i < n_step_w; i++) {
|
||||
ACL_CHECK(aclrtMemcpyAsync(
|
||||
cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||
cur_dst_buffer += KH * KW * ggml_type_size(dst->type);
|
||||
cur_permute_buffer += KH * KW * IC * ggml_type_size(dst->type);
|
||||
cur_dst_buffer =
|
||||
(char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
|
||||
cur_permute_buffer = (char*)cur_permute_buffer +
|
||||
KH * KW * IC * ggml_type_size(dst->type);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
offset = KH * KW * n_step_w *
|
||||
ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
|
||||
tmp_permute_buffer + offset, offset,
|
||||
(char*)tmp_permute_buffer + offset, offset,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue