from dewarp.models.densenetccnl import dnetccnl
from dewarp.models.unetnc import UnetGenerator


def get_model(name, n_classes=1, in_channels=3):
    """Construct and return the model registered under ``name``.

    Args:
        name: Registry key of the model; currently ``'dnetccnl'`` or
            ``'unetnc'``.
        n_classes: Number of output channels/classes forwarded to the
            model constructor (as ``out_channels`` / ``output_nc`` /
            ``n_classes`` depending on the model).
        in_channels: Number of input channels forwarded to the model
            constructor (as ``in_channels`` / ``input_nc``).

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: If ``name`` is not a registered model.
    """
    model = _get_model_instance(name)

    if name == 'dnetccnl':
        # Image size and base filter count are fixed for this model.
        model = model(img_size=128, in_channels=in_channels,
                      out_channels=n_classes, filters=32)
    elif name == 'unetnc':
        # Number of downsampling stages is fixed for this model.
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    else:
        # Fallback for registry entries without a dedicated branch above;
        # assumes their constructor accepts ``n_classes`` — confirm when
        # registering new models.
        model = model(n_classes=n_classes)

    return model


def _get_model_instance(name):
    """Return the model class/factory registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a registered model. Previously the
            lookup failure was printed and ``None`` returned, which made
            ``get_model`` fail later with an opaque ``TypeError``.
    """
    try:
        return {
            'dnetccnl': dnetccnl,
            'unetnc': UnetGenerator,
        }[name]
    except (KeyError, TypeError):
        # KeyError: unknown name; TypeError: unhashable name (e.g. a list).
        raise ValueError('Model {} not available'.format(name)) from None