Add assert in ggml_cuda_op_pool2d
This commit is contained in:
parent
0d94da7cbb
commit
ca4ec6d867
1 changed files with 3 additions and 0 deletions
|
@ -8723,6 +8723,9 @@ static void ggml_cuda_op_pool2d(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
const int32_t * opts = (const int32_t *)dst->op_params;
|
const int32_t * opts = (const int32_t *)dst->op_params;
|
||||||
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
||||||
const int k0 = opts[1];
|
const int k0 = opts[1];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue