# Autogenerated via:
# python ~/code/watch/dev/maintain/mirror_package_geowatch.py
"""Mirror of ``geowatch.tasks.rutgers_material_change_detection.models.siamese_fusion_model``.

Attribute access on this module (``module.name``) and ``dir(module)`` are
forwarded to the mirrored geowatch module through the PEP 562 module-level
``__getattr__`` / ``__dir__`` hooks defined below.

NOTE(review): this file is generated; the ``__main__`` import fix below should
also be applied in ``mirror_package_geowatch.py`` or it will be lost on
regeneration.
"""
from geowatch.tasks.rutgers_material_change_detection.models.siamese_fusion_model import print


def __getattr__(key):
    """Resolve ``key`` as an attribute of the mirrored geowatch module.

    Python calls this only when normal attribute lookup on this module fails
    (PEP 562). Raises ``AttributeError`` if the mirrored module has no
    attribute named ``key``.
    """
    import geowatch.tasks.rutgers_material_change_detection.models.siamese_fusion_model as mirror
    return getattr(mirror, key)


def __dir__():
    """Return the attribute names of the mirrored geowatch module."""
    import geowatch.tasks.rutgers_material_change_detection.models.siamese_fusion_model as mirror
    return dir(mirror)


if __name__ == '__main__':
    # The module-level ``__getattr__`` above is NOT consulted for bare global
    # name lookups inside this file (PEP 562), so ``torch`` and
    # ``SiameseFusion`` must be imported explicitly; otherwise this demo
    # raises NameError.
    import torch
    from geowatch.tasks.rutgers_material_change_detection.models.siamese_fusion_model import SiameseFusion

    # Pass dummy data through model.
    F, C, H, W = 2, 4, 250, 250
    data = torch.zeros([F, C, H, W])
    model = SiameseFusion(input_size=H, num_channels=C, decoder_type="transpose_conv")
    # Encode the first and last frames together; unsqueeze(0) adds a leading
    # dimension to each (presumably the batch dim -- confirm against encode()).
    feats = model.encode(data[0].unsqueeze(0), data[-1].unsqueeze(0))
    output = model.decode(feats, H, W)
    print("Input shape: ", [C, H, W])
    print("Output shape: ", output.shape)