# Autogenerated via:
# python ~/code/watch/dev/maintain/mirror_package_geowatch.py
"""
Lazy mirror of ``geowatch.tasks.rutgers_material_change_detection.models.discritizers``.

Attribute access and ``dir()`` on this module are forwarded to the geowatch
module via module-level ``__getattr__`` / ``__dir__`` (PEP 562). The geowatch
module is only imported when one of those hooks actually needs it.
"""


def __getattr__(key):
    """
    Forward an attribute lookup on this module to the mirrored geowatch module.

    Python calls this hook only for names that are not found in this module's
    own globals.

    Args:
        key (str): the attribute name being looked up.

    Returns:
        The value of ``key`` on the mirrored geowatch module.

    Raises:
        AttributeError: if the mirrored module has no attribute ``key``, or if
            ``key`` is a dunder name (e.g. ``__all__``) and the mirrored
            module cannot be imported.
        ImportError: if the mirrored module cannot be imported and ``key`` is
            an ordinary (non-dunder) name.
    """
    try:
        import geowatch.tasks.rutgers_material_change_detection.models.discritizers as mirror
    except ImportError as ex:
        if key.startswith('__') and key.endswith('__'):
            # Protocol probes such as the ``__all__`` lookup performed by
            # ``from <module> import *`` (or ``hasattr`` checks) only tolerate
            # AttributeError. Reporting a missing backend as AttributeError
            # lets them fall back instead of crashing; the original import
            # error is kept as the cause.
            raise AttributeError(
                f'module {__name__!r} has no attribute {key!r}') from ex
        # For real names, surface the import failure so the root cause is clear.
        raise
    return getattr(mirror, key)


def __dir__():
    """
    Return ``dir()`` of the mirrored geowatch module.

    This makes ``dir()`` on this module list the names that ``__getattr__``
    forwards. Importing the mirrored module may raise ImportError if geowatch
    is not installed.
    """
    import geowatch.tasks.rutgers_material_change_detection.models.discritizers as mirror
    return dir(mirror)


if __name__ == '__main__':
    # Smoke test: push a random feature tensor through ResidualDiscritizer.
    # Both names must be imported explicitly: the module-level ``__getattr__``
    # above only serves attribute access on this module object
    # (``module.name``). Bare names used in this file's own code are resolved
    # from its globals, so without these imports they would raise NameError.
    import torch
    from geowatch.tasks.rutgers_material_change_detection.models.discritizers import ResidualDiscritizer

    batch_size, n_frames, n_tokens, token_dim = 4, 5, 25, 100
    # Random input shaped (batch_size, n_frames, n_tokens, token_dim); that
    # ResidualDiscritizer expects this layout is assumed -- TODO confirm.
    test_feats = torch.randn([batch_size, n_frames, n_tokens, token_dim])
    n_classes = 30
    norm = 'l2'
    discritizer = ResidualDiscritizer(n_classes, token_dim, norm=norm)
    discritizer(test_feats)