#pragma once

#include <ATen/cuda/CUDAContext.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <c10/cuda/CUDAException.h>
#include <curand_kernel.h>

#include <algorithm>
#include <mutex>
#include <utility>

namespace {
constexpr int UNROLL = 4;
} // namespace

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
                                          scalar_t *outputs, uint8_t *mask,
                                          IndexType totalElements,
                                          accscalar_t p,
                                          std::pair<uint64_t, uint64_t> seeds) {
  // p is the keep probability; kept elements are scaled by 1/p (inverted
  // dropout), so no rescaling is needed at inference time.
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  // Round the element count up to a multiple of the grid-wide unrolled
  // stride so every thread runs the same number of iterations and therefore
  // consumes the same number of Philox draws.
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    // One curand call yields four uniforms, one per unrolled element.
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
                                        scalar_t const *add_inputs,
                                        scalar_t *outputs, uint8_t *mask,
                                        IndexType totalElements, accscalar_t p,
                                        std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // Accumulate the residual add in accscalar_t to avoid losing
        // precision when scalar_t is half.
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] = static_cast<scalar_t>(
            static_cast<accscalar_t>(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
                                scalar_t const *add_inputs, scalar_t *outputs,
                                IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] + add_src[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
                                         scalar_t *outputs,
                                         uint8_t const *mask,
                                         IndexType totalElements,
                                         accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = static_cast<scalar_t>(inputs[li]);
        msk[ii] = static_cast<scalar_t>(mask[li]);
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale *
                      static_cast<accscalar_t>(msk[ii]);
      }
    }
  }
}
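// ---------------------------------------------------------------------------
// Illustrative CPU reference for the dropout kernels above (our addition, a
// minimal sketch for unit testing; the name `dropout_reference` is
// hypothetical and not part of the original header). Inverted dropout keeps
// an element with probability p and rescales it by 1/p, so the expected
// value of each output equals its input.
template <typename scalar_t, typename accscalar_t, typename IndexType>
void dropout_reference(scalar_t const *inputs, scalar_t *outputs,
                       uint8_t const *mask, IndexType totalElements,
                       accscalar_t p) {
  accscalar_t pinv = accscalar_t(1) / p;
  for (IndexType i = 0; i < totalElements; ++i) {
    // mask[i] is 1 where the element was kept and 0 where it was dropped.
    outputs[i] = static_cast<scalar_t>(static_cast<accscalar_t>(inputs[i]) *
                                       static_cast<accscalar_t>(mask[i]) *
                                       pinv);
  }
}
// ---------------------------------------------------------------------------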
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs,
                             uint8_t *mask, IndexType totalElements,
                             accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  // Cap the grid at the number of blocks the device can keep resident at
  // once; the grid-stride loop in the kernel covers the remainder.
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // Number of times random numbers are generated per thread, used to offset
  // the Philox counter in the random state.
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
  }

  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                           scalar_t *outputs, uint8_t *mask,
                           IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  // Number of times random numbers are generated per thread, used to offset
  // the Philox counter in the random state.
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
            counter_offset);
  }

  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p,
          rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs,
                   scalar_t *outputs, IndexType totalElements) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}
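// ---------------------------------------------------------------------------
// The launchers above and below all size their grids the same way: one
// thread per element, rounded up, then capped at the number of blocks the
// device can keep resident (maxThreadsPerMultiProcessor / block_size blocks
// per SM, across all SMs). A standalone sketch of that logic (our addition,
// illustrative only; the launchers keep their inline copies):
inline unsigned int example_capped_grid_size(int64_t total_elements,
                                             int block_size) {
  auto const *props = at::cuda::getCurrentDeviceProperties();
  unsigned int blocks_per_sm =
      props->maxThreadsPerMultiProcessor / block_size;
  unsigned int max_resident_blocks =
      (unsigned int)props->multiProcessorCount * blocks_per_sm;
  unsigned int blocks_needed =
      (unsigned int)((total_elements + block_size - 1) / block_size);
  return std::min(max_resident_blocks, blocks_needed);
}
// ---------------------------------------------------------------------------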
template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs,
                            uint8_t const *mask, IndexType totalElements,
                            accscalar_t scale) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor /
      block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()
                            ->multiProcessorCount *
                        blocks_per_sm,
                    grid.x);

  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
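// ---------------------------------------------------------------------------
// Illustrative usage (our addition, not part of the original header): a
// minimal sketch of how these launchers might be wired into an ATen op,
// assuming the caller passes the *keep* probability. The wrapper names
// `example_fused_dropout` / `example_fused_dropout_backward` and the three
// includes below are ours, added only for the example.
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>

inline std::pair<at::Tensor, at::Tensor>
example_fused_dropout(at::Tensor const &input, double keep_prob) {
  auto output = at::empty_like(input);
  auto mask = at::empty(input.sizes(), input.options().dtype(at::kByte));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "example_fused_dropout", [&] {
        // accscalar_t follows the convention used above: float accumulation
        // when scalar_t is half.
        using accscalar_t = at::acc_type<scalar_t, true>;
        apex_fused_dropout_cuda<scalar_t, accscalar_t, int64_t>(
            input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
            mask.data_ptr<uint8_t>(), input.numel(),
            static_cast<accscalar_t>(keep_prob));
      });
  return {output, mask};
}

// The backward pass can reuse the saved mask: scaling grad_output by
// mask / keep_prob is exactly apex_masked_scale_cuda with
// scale = 1 / keep_prob.
inline at::Tensor example_fused_dropout_backward(at::Tensor const &grad_output,
                                                 at::Tensor const &mask,
                                                 double keep_prob) {
  auto grad_input = at::empty_like(grad_output);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "example_fused_dropout_backward", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        apex_masked_scale_cuda<scalar_t, accscalar_t, int64_t>(
            grad_output.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
            mask.data_ptr<uint8_t>(), grad_output.numel(),
            static_cast<accscalar_t>(1.0 / keep_prob));
      });
  return grad_input;
}
// ---------------------------------------------------------------------------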