fix potential int overflows
This commit is contained in:
parent 1120d94b60
commit 066f6cf3e1

3 changed files with 24 additions and 24 deletions
cross-entropy-loss.cu

@@ -10,8 +10,8 @@ static __global__ void cross_entropy_loss_f32(
         const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
     extern __shared__ float tmp[];
 
-    logits += blockIdx.x*nclasses;
-    labels += blockIdx.x*nclasses;
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
 
     // Find maximum for softmax:
     float max_logit = -INFINITY;
@@ -55,9 +55,9 @@ static __global__ void cross_entropy_loss_back_f32(
         float * __restrict__ dst, const int nclasses) {
     extern __shared__ float tmp[];
 
-    logits += blockIdx.x*nclasses;
-    labels += blockIdx.x*nclasses;
-    dst    += blockIdx.x*nclasses;
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
+    dst    += int64_t(blockIdx.x)*nclasses;
 
     float maxval = -INFINITY;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
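Both kernels fix the same pattern: `blockIdx.x` is an `unsigned int` and `nclasses` an `int`, so `blockIdx.x*nclasses` is evaluated in 32-bit arithmetic and can wrap once a row's starting element index no longer fits in 32 bits, even though the pointer itself is 64-bit. Casting one operand to `int64_t` before the multiplication forces a 64-bit product. A minimal sketch of the idea outside ggml (kernel name and shapes are hypothetical, not part of this commit):

    #include <cstdint>

    // One block per row; each block first offsets its pointers to the start of its row.
    __global__ void scale_rows(const float * __restrict__ src, float * __restrict__ dst,
                               const int ncols, const float factor) {
        // Wrong: blockIdx.x (unsigned int) * ncols (int) is a 32-bit multiply and
        // wraps for large row indices even though the final offset fits in 64 bits:
        //     src += blockIdx.x*ncols;

        // Right: promote to 64 bits *before* multiplying.
        src += int64_t(blockIdx.x)*ncols;
        dst += int64_t(blockIdx.x)*ncols;

        for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
            dst[i] = factor*src[i];
        }
    }

Note that wrapping the whole expression, as in `int64_t(blockIdx.x*nclasses)`, would not help: the cast only runs after the 32-bit multiplication has already overflowed.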
@@ -115,10 +115,10 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int nbytes_shared = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);
 
-    const int id = ggml_cuda_get_device();
-    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
@@ -169,10 +169,10 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
 
     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int nbytes_shared = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);
 
-    const int id = ggml_cuda_get_device();
-    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     if (nbytes_shared <= smpbo) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
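The host-side changes in both launch functions are about truncation rather than the multiplication itself: `ne00` is an `int64_t` in ggml's tensor metadata and `sizeof(float)` a `size_t`, so `ne00*sizeof(float)` is already computed in 64 bits, but storing it in a `const int` could narrow the byte count before it is compared against the device's opt-in shared-memory limit and passed to the launch. Keeping both `nbytes_shared` and `smpbo` as `size_t` avoids the narrowing and makes the `nbytes_shared <= smpbo` check a same-type, unsigned comparison. In practice a per-row shared-memory requirement is unlikely to exceed `INT_MAX`, so this reads as defensive consistency. A small host-side sketch of the narrowing (the row length is deliberately exaggerated and hypothetical):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = INT64_C(600000000);      // hypothetical, oversized row length

        const int    narrowed  = ne00*sizeof(float);  // 2.4 GB does not fit in a 32-bit int
        const size_t preserved = ne00*sizeof(float);  // full 64-bit byte count

        printf("stored as int:    %d\n", narrowed);
        printf("stored as size_t: %zu\n", preserved);
        return 0;
    }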
norm.cu

@@ -5,8 +5,8 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    x += row*ncols;
-    dst += row*ncols;
+    x += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
@@ -101,8 +101,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    x += row*ncols;
-    dst += row*ncols;
+    x += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
@@ -140,9 +140,9 @@ static __global__ void rms_norm_back_f32(
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    grad += row*ncols;
-    xf += row*ncols;
-    dst += row*ncols;
+    grad += int64_t(row)*ncols;
+    xf += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;
 
     float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
     float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
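For the norm kernels the offending product is `row*ncols` with both operands `int`, so the element offset wraps once it exceeds 2^31 - 1 = 2,147,483,647. As a worked example (shape chosen for illustration, not taken from the commit): with ncols = 32768 the first affected row is row = 65536, since 65536 × 32768 = 2^31; that is a 65,536 × 32,768 f32 tensor of 8 GiB, large but within reach for activations or gradients during training. Computing the offset as `int64_t(row)*ncols` keeps it exact.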
softmax.cu

@@ -23,9 +23,9 @@ static __global__ void soft_max_f32(
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
-    x += rowx*ncols;
-    mask += rowy*ncols * (mask != nullptr);
-    dst += rowx*ncols;
+    x += int64_t(rowx)*ncols;
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    dst += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -124,9 +124,9 @@ static __global__ void soft_max_back_f32(
     const int tid = threadIdx.x;
     const int rowx = blockIdx.x;
 
-    grad += rowx*ncols;
-    dstf += rowx*ncols;
-    dst += rowx*ncols;
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst += int64_t(rowx)*ncols;
 
     float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
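The soft_max_f32 hunk also carries over a small existing trick: the mask offset is multiplied by `(mask != nullptr)`, so when no mask tensor is supplied the null pointer is advanced by zero instead of by a row offset. The `int64_t` cast still has to sit on the index multiply itself, because the 0/1 factor is applied only after `rowy*ncols` would already have wrapped. A generic sketch of that pattern (helper name hypothetical):

    #include <cstdint>

    // Branchless offset for an optional row pointer: advance only when the pointer
    // is non-null, with the index multiply done in 64 bits before the 0/1 factor.
    template <typename T>
    __host__ __device__ static const T * row_or_null(const T * base, const int row, const int ncols) {
        return base + int64_t(row)*ncols * (base != nullptr);
    }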