#pragma once #include #include #include namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// template struct LaunchParams{ size_t workspace_bytes; size_t barrier_size; cudaDeviceProp * props; cudaStream_t stream; Params params; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct ParamsBase { ParamsBase() : ctas_per_col(0) , rows(0) , cols(0) , x(nullptr) , mu(nullptr) , rs(nullptr) , gamma(nullptr) , workspace(nullptr) , barrier(nullptr) { } // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. int ctas_per_col; // Input is interpreted as matrix. We normalize across columns. int rows; int cols; // Common data pointers. void *x; void *mu; void *rs; void *gamma; // Multi-CTA workspace in gmem. void *workspace; // Multi-CTA sync barriers in gmem. int *barrier; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct FwdParams : public ParamsBase { FwdParams() : ParamsBase() , z(nullptr) , beta(nullptr) , epsilon(0.f) { } // Output of LN FWD. void *z; void *beta; float epsilon; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct BwdParams : public ParamsBase { BwdParams() : ParamsBase() , dz(nullptr) , dbeta_part(nullptr) , dgamma_part(nullptr) , dx(nullptr) , dbeta(nullptr) , dgamma(nullptr) { } // Input: gradient wrt. LN FWD output. void *dz; // Workspace for Wgrad pre-reduction. void *dbeta_part; void *dgamma_part; // Output: Dgrad. void *dx; // Output: Wgrad. void *dbeta; void *dgamma; }; //////////////////////////////////////////////////////////////////////////////////////////////////// using FwdFunction = std::function&, const bool)>; using BwdFunction = std::function&, const bool)>; using FunctionKey = uint64_t; using FwdRegistry = std::unordered_map; using BwdRegistry = std::unordered_map; extern FwdRegistry FWD_FUNCS; extern BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct TypeId{}; template<> struct TypeId{ constexpr static uint32_t Value = 0; }; template<> struct TypeId{ constexpr static uint32_t Value = 1; }; template<> struct TypeId{ constexpr static uint32_t Value = 2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Type2Key{ constexpr static uint32_t Value = TypeId::Value << S; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct WeightType2Key : public Type2Key{}; template struct InputType2Key : public Type2Key{}; template struct OutputType2Key : public Type2Key{}; template struct ComputeType2Key : public Type2Key{}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Types2Key{ constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; constexpr static inline uint64_t get(const uint64_t hidden_size){ constexpr uint64_t type_key = Value; return (type_key << 32) | hidden_size; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct FwdRegistrar{ FwdRegistrar(FwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); FWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdRegistrar{ BwdRegistrar(BwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); BWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm