25 lines
695 B
Python
25 lines
695 B
Python
from dewarp.models.densenetccnl import dnetccnl
|
|
from dewarp.models.unetnc import UnetGenerator
|
|
|
|
|
|
def get_model(name, n_classes=1, in_channels=3):
    """Instantiate the model registered under ``name``.

    Args:
        name: Registry key of the model. ``'dnetccnl'`` and ``'unetnc'``
            are built with the fixed constructor settings below; any other
            registered name gets only ``n_classes``.
        n_classes: Forwarded as the model's output size
            (``out_channels`` / ``output_nc`` / ``n_classes`` depending on
            the model).
        in_channels: Forwarded as the model's input-channel count
            (``in_channels`` / ``input_nc``).

    Returns:
        A freshly constructed instance of the selected model.

    Raises:
        ValueError: If no model is registered under ``name``.
    """
    model_cls = _get_model_instance(name)
    if model_cls is None:
        # The lookup may signal an unknown name by returning None; fail here
        # with a clear message instead of "'NoneType' object is not callable".
        raise ValueError('Model {} not available'.format(name))

    if name == 'dnetccnl':
        # Fixed constructor settings for this model.
        return model_cls(img_size=128, in_channels=in_channels,
                         out_channels=n_classes, filters=32)
    if name == 'unetnc':
        # Fixed depth for this model.
        return model_cls(input_nc=in_channels, output_nc=n_classes,
                         num_downs=7)
    # Fallback for any other registered model that accepts ``n_classes``.
    return model_cls(n_classes=n_classes)
|
|
|
|
|
|
def _get_model_instance(name):
    """Return the (uninstantiated) model constructor registered under ``name``.

    Args:
        name: Registry key; currently ``'dnetccnl'`` or ``'unetnc'``.

    Returns:
        The model class/callable mapped to ``name``.

    Raises:
        ValueError: If ``name`` is not a registered model name.
    """
    registry = {
        'dnetccnl': dnetccnl,
        'unetnc': UnetGenerator,
    }
    try:
        return registry[name]
    except KeyError:
        # Raise instead of printing and returning None: a None result only
        # fails later, far from the cause, when the caller tries to call it.
        raise ValueError('Model {} not available'.format(name)) from None
|