sycl: move downsample global_range into common
Signed-off-by: zhentaoyu <zhentao.yu@intel.com>
This commit is contained in:
parent
df3f1c1850
commit
8bd46e8450
4 changed files with 15 additions and 14 deletions
|
@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
|
||||||
<< ", line:" << __LINE__ << std::endl;
|
<< ", line:" << __LINE__ << std::endl;
|
||||||
std::exit(1);
|
std::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
|
||||||
|
const int64_t max_range = std::numeric_limits<int>::max();
|
||||||
|
int64_t sycl_down_blk_size = block_size;
|
||||||
|
int64_t global_range = accumulate_block_num * sycl_down_blk_size;
|
||||||
|
while(global_range > max_range) {
|
||||||
|
sycl_down_blk_size /= 2;
|
||||||
|
global_range = accumulate_block_num * sycl_down_blk_size;
|
||||||
|
}
|
||||||
|
return sycl_down_blk_size;
|
||||||
|
}
|
||||||
|
|
|
@ -352,4 +352,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
|
||||||
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
|
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
|
||||||
|
|
||||||
#endif // GGML_SYCL_COMMON_HPP
|
#endif // GGML_SYCL_COMMON_HPP
|
||||||
|
|
|
@ -437,13 +437,7 @@ static void convert_unary_sycl(const void *__restrict__ vx,
|
||||||
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
||||||
|
|
||||||
// decrease global range when it exceeds the max int
|
// decrease global range when it exceeds the max int
|
||||||
int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE;
|
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
|
||||||
const int64_t max_range = std::numeric_limits<int>::max();
|
|
||||||
int64_t global_range = num_blocks * local_size;
|
|
||||||
while(global_range > max_range) {
|
|
||||||
local_size /= 2;
|
|
||||||
global_range = num_blocks * local_size;
|
|
||||||
}
|
|
||||||
sycl::range<3> block_nums(1, 1, num_blocks);
|
sycl::range<3> block_nums(1, 1, num_blocks);
|
||||||
sycl::range<3> local_range(1, 1, local_size);
|
sycl::range<3> local_range(1, 1, local_size);
|
||||||
{
|
{
|
||||||
|
|
|
@ -64,13 +64,7 @@ static void im2col_sycl(
|
||||||
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
|
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
|
||||||
|
|
||||||
// decrease global range when it exceeds the max int
|
// decrease global range when it exceeds the max int
|
||||||
int local_size = SYCL_IM2COL_BLOCK_SIZE;
|
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
|
||||||
const int64_t max_range = std::numeric_limits<int>::max();
|
|
||||||
int64_t global_range = batch * IC * OH * num_blocks * local_size;
|
|
||||||
while(global_range > max_range) {
|
|
||||||
local_size /= 2;
|
|
||||||
global_range = batch * IC * OH * num_blocks * local_size;
|
|
||||||
}
|
|
||||||
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
|
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
|
||||||
sycl::range<3> local_range(1, 1, local_size);
|
sycl::range<3> local_range(1, 1, local_size);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue