Move doc_dewarp
services/paddle_services/doc_dewarp/plots.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import copy
import os

import matplotlib.pyplot as plt
import paddle.optimizer as optim

from GeoTr import GeoTr


def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
    """
    Plot the learning rate produced by a Paddle LR scheduler over training.

    Copies of the optimizer and scheduler are stepped so the originals are
    left untouched; 30 scheduler steps per epoch are simulated.
    """
    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    lr = []
    for _ in range(epochs):
        for _ in range(30):  # 30 steps per epoch
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()

    # Convert step indices to fractional epochs for the x-axis.
    epoch = [float(i) / 30.0 for i in range(len(lr))]

    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()


if __name__ == "__main__":
    model = GeoTr()

    # One-cycle schedule: 65 epochs x 30 steps = 1950 total steps.
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
    )
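With phase_pct=0.1, OneCycleLR spends the first 10% of the 1950 steps (roughly the first 6.5 epochs) increasing the learning rate toward max_learning_rate, then anneals it toward end_learning_rate. The helper is not tied to that schedule; as a minimal usage sketch (assuming plots.py and GeoTr are importable from the working directory, and using CosineAnnealingDecay purely as an illustrative alternative), any other Paddle scheduler can be plotted the same way:

    # Sketch only: plots.py/GeoTr import paths are assumptions, and
    # CosineAnnealingDecay stands in for whichever schedule you want to inspect.
    import paddle.optimizer as optim

    from GeoTr import GeoTr
    from plots import plot_lr_scheduler

    model = GeoTr()
    # Cosine decay from 1e-4 over the same 65 x 30 = 1950 steps.
    scheduler = optim.lr.CosineAnnealingDecay(learning_rate=1e-4, T_max=1950)
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./")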