18 lines
459 B
Python
18 lines
459 B
Python
'''
Miscellaneous utility functions.
'''
|
|
from collections import OrderedDict
|
|
|
|
|
|
def convert_state_dict(state_dict, prefix="module."):
    """Convert a state dict saved from a ``DataParallel`` module to a plain one.

    Keys that start with ``prefix`` have that prefix removed; all other keys
    are copied unchanged, so passing an already-unwrapped state dict is safe.
    The input mapping is not modified -- a new ``OrderedDict`` is returned,
    preserving the input's iteration order.

    :param state_dict: the loaded (possibly DataParallel-wrapped) model state,
        a mapping from parameter names to values.
    :param prefix: the key prefix to strip; defaults to ``"module."``.
    :return: a new ``OrderedDict`` with the prefix stripped from matching keys.
    """
    new_state_dict = OrderedDict()
    cut = len(prefix)
    for k, v in state_dict.items():
        # Only strip when the prefix is actually present; blindly slicing
        # would mangle keys that were never wrapped.
        name = k[cut:] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
|