移动doc_dewarp
This commit is contained in:
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# Seed the global RNG with a fixed value so that random.shuffle() in run()
# produces the same permutation for the same input ordering on every execution.
random.seed(1234567)
|
||||
|
||||
|
||||
def _write_split(list_path, image_paths):
    """Write one ``<parent dir>/<file stem>`` entry per image to *list_path*."""
    with open(list_path, "w") as file:
        for image_path in image_paths:
            image = Path(image_path)
            # Drop only the trailing extension; splitting on ".png" would cut
            # the entry short whenever ".png" also occurs earlier in the
            # directory or file name.
            entry = os.path.join(image.parent.name, image.stem)
            file.write("%s\n" % entry)


def run(args):
    """Split the images under ``<data_root>/img/*/*.png`` into train/val lists.

    Two text files, ``train.txt`` and ``val.txt``, are written directly into
    ``data_root``. Each line has the form ``<parent dir>/<file name without
    extension>``. The images are sorted and then shuffled with the global
    ``random`` module state before being split.

    Args:
        args: Namespace providing ``data_root`` (str; a leading ``~`` is
            expanded) and ``train_ratio`` (float; the fraction of images
            that goes into ``train.txt``, the remainder goes to ``val.txt``).
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio

    data_path = os.path.join(data_root, "img", "*", "*.png")
    # glob returns matches in arbitrary, filesystem-dependent order. Sort the
    # list (keeping the result!) so that a seeded shuffle yields the same
    # split on every machine.
    img_list = sorted(glob.glob(data_path, recursive=True))
    random.shuffle(img_list)

    train_size = int(len(img_list) * ratio)

    train_text_path = os.path.join(data_root, "train.txt")
    _write_split(train_text_path, img_list[:train_size])

    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(val_text_path, img_list[train_size:])

    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: collect the dataset location and the train
    # fraction, log the parsed options, then generate the split files.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        default="~/datasets/doc3d",
        type=str,
        help="Data path to load data",
    )
    cli.add_argument(
        "--train_ratio",
        default=0.8,
        type=float,
        help="Ratio of training data",
    )

    args = cli.parse_args()
    logger.info(args)
    run(args)
|
||||
Reference in New Issue
Block a user