fix build warning
This commit is contained in:
parent
c9f65bfa76
commit
c1111cc096
1 changed file with 14 additions and 11 deletions
|
@ -1313,10 +1313,11 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
||||||
ggml_tensor* dst, ggml_tensor* src1,
|
ggml_tensor* dst,
|
||||||
aclTensor* tmp_cast_tensor,
|
ggml_tensor* src1,
|
||||||
aclTensor* tmp_im2col_tensor) {
|
aclTensor* tmp_cast_tensor,
|
||||||
|
aclTensor* tmp_im2col_tensor) {
|
||||||
// Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
|
// Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
|
||||||
int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
|
int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
|
||||||
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
|
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
|
||||||
|
@ -1334,7 +1335,7 @@ void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
|
||||||
ACL_CHECK(aclDestroyTensor(acl_dst));
|
ACL_CHECK(aclDestroyTensor(acl_dst));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cann_im2col_1d_post_process(
|
static void ggml_cann_im2col_1d_post_process(
|
||||||
ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
|
ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
|
||||||
aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
|
aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
|
||||||
const std::vector<int64_t>& im2col_op_params) {
|
const std::vector<int64_t>& im2col_op_params) {
|
||||||
|
@ -1389,24 +1390,26 @@ void ggml_cann_im2col_1d_post_process(
|
||||||
size_t size_cpy = KH * KW * ggml_type_size(dst->type);
|
size_t size_cpy = KH * KW * ggml_type_size(dst->type);
|
||||||
|
|
||||||
for (int c = 0; c < IC; c++) {
|
for (int c = 0; c < IC; c++) {
|
||||||
cur_permute_buffer = tmp_permute_buffer + offset +
|
cur_permute_buffer = (char*)tmp_permute_buffer + offset +
|
||||||
KH * KW * c * ggml_type_size(dst->type);
|
KH * KW * c * ggml_type_size(dst->type);
|
||||||
cur_dst_buffer =
|
cur_dst_buffer = (char*)dst->data +
|
||||||
dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type);
|
c * KH * KW * n_step_w * ggml_type_size(dst->type);
|
||||||
|
|
||||||
for (int i = 0; i < n_step_w; i++) {
|
for (int i = 0; i < n_step_w; i++) {
|
||||||
ACL_CHECK(aclrtMemcpyAsync(
|
ACL_CHECK(aclrtMemcpyAsync(
|
||||||
cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
|
cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
|
||||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||||
cur_dst_buffer += KH * KW * ggml_type_size(dst->type);
|
cur_dst_buffer =
|
||||||
cur_permute_buffer += KH * KW * IC * ggml_type_size(dst->type);
|
(char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
|
||||||
|
cur_permute_buffer = (char*)cur_permute_buffer +
|
||||||
|
KH * KW * IC * ggml_type_size(dst->type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
offset = KH * KW * n_step_w *
|
offset = KH * KW * n_step_w *
|
||||||
ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
|
ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
|
||||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
|
ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
|
||||||
tmp_permute_buffer + offset, offset,
|
(char*)tmp_permute_buffer + offset, offset,
|
||||||
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue