import torch import transducer_loss_cuda import transducer_joint_cuda class TransducerJoint(torch.nn.Module): """Transducer joint Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural Networks Arguments: pack_output (bool, optional): whether to pack the output in a compact form with don't-care data being removed. (default: False) relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1 (default: False) dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1 (default: False) opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. (default: 1) fwd_tile_size (int, optional): tile size used in forward operation. This argument will be ignored if opt != 1. (default: 4) dropout_prob (float, optional): dropout probability. (default: 0.0) probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout operation. When this argument is set to True, the mask can be accessed through self.mask_probe. (default: false) """ def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, dropout_prob=0, probe_mask=False): super(TransducerJoint, self).__init__() self.pack_output = pack_output self.relu = relu self.dropout = dropout self.dropout_prob = dropout_prob self.opt = opt self.fwd_tile_size = fwd_tile_size self.dummy_batch_offset = torch.empty(0) masked = self.relu or self.dropout self.mask_probe = [] if masked and probe_mask else None if masked and opt != 1: raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1") def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): """Forward operation of transducer joint Arguments: f (tensor): transcription vector from encode block of shape (B, T, H). g (tensor): prediction vector form predict block of shape (B, U, H). f_len (tensor): length of transcription vector for each batch. g_len (tensor): length of prediction vector minus 1 for each batch. batch_offset (tensor, optional): tensor containing the offset of each batch in the results. For example, batch offset can be obtained from: batch_offset = torch.cumsum(f_len*g_len, dim=0) This argument is required if pack_output == True, and is ignored if pack_output == False. (default: None) packed_batch (int, optional): the batch size after packing. This argument is ignored if pack_output == False. (default: 0) """ my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset if self.pack_output and (batch_offset is None or packed_batch == 0): raise Exception("Please specify batch_offset and packed_batch when packing is enabled") dropout = self.dropout and self.training # only dropout for training return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, my_batch_offset, packed_batch, self.opt, self.fwd_tile_size, self.dropout_prob, self.mask_probe) class TransducerLoss(torch.nn.Module): """Transducer loss Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural Networks Arguments: fuse_softmax_backward (bool, optional) whether to fuse the backward of transducer loss with softmax. (default: True) opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1) packed_input (bool, optional): whether to pack the output in a compact form with don't-care data being removed. (default: False) """ def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False): super(TransducerLoss, self).__init__() self.fuse_softmax_backward = fuse_softmax_backward self.opt = opt self.packed_input = packed_input self.dummy_batch_offset = torch.empty(0) def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, debug_list=None): """Forward operation of transducer joint Arguments: x (tensor): input tensor to the loss function with a shape of (B, T, U, H). label (tensor): labels for the input data. f_len (tensor): lengths of the inputs in the time dimension for each batch. y_len (tensor): lengths of the labels for each batch. blank_idx (int): index for the null symbol. batch_offset (tensor, optional): tensor containing the offset of each batch in the input. For example, batch offset can be obtained from: batch_offset = torch.cumsum(f_len*(y_len+1), dim=0) This argument is required if packed_input == True, and is ignored if packed_input == False. (default: None) max_f_len (int, optional): maximum length of the input in the time dimension. For example, it can be obtained as max_f_len = max(f_len) This argument is required if packed_input == True, and is ignored if packed_input == False. (default: None) (default: None) debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated in the forward operation will be attached to this list for debug purpose. (default: None) """ if self.packed_input: if batch_offset is None or max_f_len is None: raise Exception("Please specify batch_offset and max_f_len when packing is \ enabled") my_batch_offset = batch_offset my_max_f_len = max_f_len else: my_batch_offset = self.dummy_batch_offset my_max_f_len = x.size(1) return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, blank_idx, self.fuse_softmax_backward, debug_list, self.opt, self.packed_input) class TransducerLossFunc(torch.autograd.Function): @staticmethod def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, fuse_softmax_backward, debug_list, opt, packed_input): if fuse_softmax_backward == False: with torch.enable_grad(): x = torch.nn.functional.log_softmax(x, dim=-1) else: x = torch.nn.functional.log_softmax(x, dim=-1) alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, opt, packed_input) if debug_list == []: debug_list += [alpha, beta] ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset) ctx.blank_idx = blank_idx ctx.fuse_softmax_backward = fuse_softmax_backward ctx.opt = opt ctx.packed_input = packed_input ctx.max_f_len = max_f_len return loss @staticmethod def backward(ctx, loss_grad): x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, ctx.fuse_softmax_backward, ctx.packed_input) if ctx.fuse_softmax_backward == False: x_grad = x.backward(x_grad) return x_grad, None, None, None, None, None, None, None, None, None, None class TransducerJointFunc(torch.autograd.Function): @staticmethod def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, opt, fwd_tile_size, dropout_prob, mask_probe): h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, pack_output, relu, dropout, dropout_prob, fwd_tile_size) masked = relu or dropout if masked: ctx.save_for_backward(h[1], f_len, g_len, batch_offset) if mask_probe is not None: mask_probe.append(h[1]) else: ctx.save_for_backward(f_len, g_len, batch_offset) ctx.pack_output = pack_output ctx.masked = relu or dropout ctx.max_f_len = f.size(1) ctx.max_g_len = g.size(1) ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1 return h[0] @staticmethod def backward(ctx, loss_grad): if ctx.masked: mask, f_len, g_len, batch_offset = ctx.saved_tensors inp = [loss_grad, mask] else: f_len, g_len, batch_offset = ctx.saved_tensors inp = [loss_grad] f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset, ctx.max_f_len, ctx.max_g_len, ctx.pack_output, ctx.scale) return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \ None, None, None