58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
import argparse
|
|
import glob
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
|
|
from loguru import logger
|
|
|
|
# Seed the module-level RNG with a fixed value so that random.shuffle() in
# run() draws from the same pseudo-random sequence on every execution.
random.seed(1234567)
|
|
|
|
|
|
def _write_split(list_path, image_paths):
    """Write one ``<parent_dir>/<file_stem>`` line per image to *list_path*."""
    with open(list_path, "w") as file:
        for image_path in image_paths:
            parts = Path(image_path).parts
            # Strip only the trailing extension of the file name. Splitting
            # the joined path on ".png" would truncate any entry whose
            # directory or file name contains ".png" somewhere else.
            stem = os.path.splitext(parts[-1])[0]
            file.write("%s\n" % os.path.join(parts[-2], stem))


def run(args):
    """Split the images under ``<data_root>/img/*/*.png`` into train/val lists.

    The matching paths are sorted, shuffled with the module-level ``random``
    generator, and cut at ``int(len(images) * train_ratio)``. The first part
    is written to ``<data_root>/train.txt`` and the rest to
    ``<data_root>/val.txt``; each line is ``<parent_dir>/<file_stem>``
    (no extension).

    Args:
        args: Namespace-like object with ``data_root`` (str; ``~`` is
            expanded) and ``train_ratio`` (float) attributes.
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio

    data_path = os.path.join(data_root, "img", "*", "*.png")
    img_list = glob.glob(data_path, recursive=True)
    if not img_list:
        logger.warning(f"No images found matching {data_path}")

    # glob returns files in arbitrary, filesystem-dependent order. Sort
    # in place (a bare ``sorted(img_list)`` discards its result) so that the
    # seeded shuffle below yields the same split on every machine.
    img_list.sort()
    random.shuffle(img_list)

    train_size = int(len(img_list) * ratio)

    train_text_path = os.path.join(data_root, "train.txt")
    _write_split(train_text_path, img_list[:train_size])

    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(val_text_path, img_list[train_size:])

    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, log them, then build
    # the train/val list files via run().
    parser = argparse.ArgumentParser()

    # Dataset root directory; "~" is expanded inside run().
    parser.add_argument(
        "--data_root", type=str, default="~/datasets/doc3d", help="Data path to load data"
    )

    # Fraction of the images assigned to the training list.
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of training data",
    )

    args = parser.parse_args()
    logger.info(args)
    run(args)
|