/****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once #include #include #include extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); //////////////////////////////////////////////////////////////////////////////////////////////////// namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Row {}; struct Col {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int M, bool = (M & (M-1)) == 0 > struct Next_power_of_two { }; template< int M > struct Next_power_of_two< M, true > { enum { VALUE = M }; }; template<> struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; }; template<> struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; }; template<> struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; }; template<> struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; }; template<> struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; }; template<> struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; }; template<> struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; }; template<> struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; }; template<> struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; }; template<> struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; template<> struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, bool = (N & (N-1)) == 0 > struct Prev_power_of_two { }; template< int N > struct Prev_power_of_two< N, true > { enum { VALUE = N }; }; template<> struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; }; template<> struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; }; template<> struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; }; template<> struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int M, int N > struct Div_up { enum { VALUE = (M + N-1) / N }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int A, int B > struct Max { enum { VALUE = A >= B ? A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int A, int B, int C > struct Max_3 { enum { VALUE = Max::VALUE, C>::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int A, int B > struct Min { enum { VALUE = A <= B ? A : B }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int SIZE_IN_BYTES > struct Uint_from_size_in_bytes { }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Uint_from_size_in_bytes<1> { using Type = uint8_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Uint_from_size_in_bytes<2> { using Type = uint16_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Uint_from_size_in_bytes<4> { using Type = uint32_t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Uint_from_size_in_bytes<8> { using Type = uint2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Uint_from_size_in_bytes<16> { using Type = uint4; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< int WARPS_M, int WARPS_N, int WARPS_K > struct Warp_masks { }; //////////////////////////////////////////////////////////////////////////////////////////////////// template<> struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; template<> struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; }; template<> struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; template<> struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; template<> struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; template<> struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; template<> struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; template<> struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; template<> struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; template<> struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; template<> struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; template<> struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; template<> struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; template<> struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; template<> struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; template<> struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; template<> struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< typename T > inline __device__ __host__ T div_up(T m, T n) { return (m + n-1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int clz(int x) { for( int i = 31; i >= 0; --i ) { if( (1 << i) & x ) { return 31 - i; } } return 32; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int find_log_2(int x, bool round_up = false) { int a = 31 - clz(x); if( round_up ) { a += (x & (x-1)) ? 1 : 0; } return a; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmul4(uint2 a, uint2 b) { uint2 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint4 a, uint4 b) { uint4 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); c.z = hmul2(a.z, b.z); c.w = hmul2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { uint4 c; c.x = hmul2(a, b.x); c.y = hmul2(a, b.y); c.z = hmul2(a, b.z); c.w = hmul2(a, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { uint32_t res; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); #else const uint32_t zero = 0u; asm volatile( \ "{\n" \ "\t .reg .f16x2 sela;\n" \ "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ "\t and.b32 %0, sela, %1;\n" "}\n" : "=r"(res) : "r"(x), "r"(zero)); #endif return res; } static inline __device__ uint32_t habs2(uint32_t x) { uint32_t res; asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // template< typename T > static inline __device__ T clamp(T x, T lb, T ub) { return x < lb ? lb : (x > ub ? ub : x); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t clamp_to_zero(uint16_t x) { uint16_t mask; asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); return mask & x; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t float_to_half(float f) { uint16_t h; asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(float a, float b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); #else uint16_t lo = float_to_half(a); uint16_t hi = float_to_half(b); asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a,a); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(const float2 &f) { return float2_to_half2(f.x, f.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { uint2 d; d.x = float2_to_half2(x, y); d.y = float2_to_half2(z, w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else d = hrelu2(hfma2(a, b, c)); #endif return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h0_h0(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float h0_to_float(uint32_t h2) { float f; asm volatile("{\n" \ ".reg .f16 lo, hi;\n" \ "mov.b32 {lo, hi}, %1;\n" \ "cvt.f32.f16 %0, lo;\n" \ "}\n" : "=f"(f) : "r"(h2)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h1_h1(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { uint16_t d; asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd4(uint2 a, uint2 b) { uint2 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd8(uint4 a, uint4 b) { uint4 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); c.z = hadd2(a.z, b.z); c.w = hadd2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 fadd4(uint4 a, uint4 b) { float4 c; c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); return reinterpret_cast(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float half_to_float(uint16_t h) { float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float2 half2_to_float2(uint32_t x) { uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); return make_float2(half_to_float(lo), half_to_float(hi)); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) { float2 tmp = half2_to_float2(h); x = tmp.x; y = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { uint16_t d; asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { uint16_t d; asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint16_t &dst) { dst = uint16_t(0); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint32_t &dst) { dst = 0u; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint2 &dst) { dst = make_uint2(0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint4 &dst) { dst = make_uint4(0u, 0u, 0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // P R E D I C A T E P A C K I N G // //////////////////////////////////////////////////////////////////////////////////////////////////// enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; //////////////////////////////////////////////////////////////////////////////////////////////////// // // G E N E R I C P R E D I C A T E D L D G S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M, typename Functor > inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) { // The number of complete bytes (where we use all the predicates in a byte). enum { COMPLETE = N / PREDS_PER_BYTE }; // Make sure we did allocate enough predicates. static_assert(Div_up::VALUE <= M, ""); // The remainder. enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; // Make sure we got the math right and the remainder is between 0 and 3. static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); // The mask to extract the predicates. enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; // Clear the fetch registers. #pragma unroll for( int ii = 0; ii < N; ++ii ) { fct.clear(ii); } // Run complete steps. bool p[PREDS_PER_BYTE]; #pragma unroll for( int ii = 0; ii < COMPLETE; ++ii ) { // The predicate. uint32_t reg = preds[ii / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. #pragma unroll for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); } } // Skip the rest of the code if we do not have a remainder. if( REMAINDER > 0 ) { // The mask to extract the predicates. enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; // The predicate register. uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. #pragma unroll for( int ii = 0; ii < REMAINDER; ++ii ) { fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int M, typename Functor > inline __device__ void load_(Functor &fct, uint32_t preds) { uint32_t tmp[1] = { preds }; load_(fct, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint8_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint16_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint32_t &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint2 &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint4 &dst, const void *ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< typename Data_type, int N > struct Ldg_functor { // Ctor. inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) : fetch_(fetch), ptrs_(ptrs) { } // Clear the element. inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); } // Trigger the loads. inline __device__ void load(int ii, bool p) { if( p ) { ldg(fetch_[ii], ptrs_[ii]); } } // The fetch registers. Data_type (&fetch_)[N]; // The pointers. const void* (&ptrs_)[N]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template< typename Data_type, int N, int M > inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { Ldg_functor fct(fetch, ptrs); load_(fct, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M > inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M > inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M > inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M > inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N, int M > inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint16_t &dst, uint32_t ptr) { asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint32_t &dst, uint32_t ptr) { asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint2 &dst, uint32_t ptr) { asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint4 &dst, uint32_t ptr) { asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x) , "=r"(dst.y) , "=r"(dst.z) , "=r"(dst.w) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S M // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void *ptr, uint8_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void *ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void *ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void *ptr, uint2 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void *ptr, uint4 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint16_t val) { asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint32_t val) { asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint2 val) { asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr) , "r"(val.x) , "r"(val.y)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void sts(uint32_t ptr, uint4 val) { asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" : : "r"(ptr) , "r"(val.x) , "r"(val.y) , "r"(val.z) , "r"(val.w)); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< typename Data_type, int N > inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) { #pragma unroll for( int ii = 0; ii < N; ++ii ) { sts(ptrs[ii], data[ii]); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N > inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N > inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N > inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// template< int N > inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) { sts_(ptrs, data); } //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha