Files
2024-09-24 17:10:56 +08:00

49 lines
1.2 KiB
Python

import copy
import os
import matplotlib.pyplot as plt
import paddle.optimizer as optim
from GeoTr import GeoTr
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir="", *, steps_per_epoch=30):
    """
    Plot the learning-rate schedule produced by ``scheduler`` and save it as
    ``lr_scheduler.png`` in ``save_dir``.

    The scheduler is deep-copied and then stepped ``epochs * steps_per_epoch``
    times. The learning rate is recorded before each step. The caller's
    scheduler object is not advanced.

    Args:
        optimizer: Optimizer that ``scheduler`` is attached to. It is accepted
            for backward compatibility but is deliberately NOT stepped. A
            shallow copy of an optimizer still shares the real parameter
            tensors, so stepping it would modify the model's weights whenever
            gradients are present. The plotted values depend only on the
            scheduler anyway.
        scheduler: Learning-rate scheduler to visualise. It must provide
            ``get_lr()`` and ``step()``.
        epochs: Number of epochs to simulate.
        save_dir: Directory for the output image. It is created if missing.
            An empty string means the current working directory.
        steps_per_epoch: Scheduler steps per epoch (previously fixed at 30).

    Raises:
        ValueError: If ``steps_per_epoch`` is not positive.
    """
    if steps_per_epoch <= 0:
        raise ValueError(f"steps_per_epoch must be positive, got {steps_per_epoch}")

    # Independent copy: stepping it must not advance the caller's schedule.
    scheduler = copy.deepcopy(scheduler)

    lrs = []
    for _ in range(epochs * steps_per_epoch):
        lrs.append(scheduler.get_lr())
        scheduler.step()

    # x-axis in (fractional) epochs: one point per scheduler step.
    epoch_axis = [i / steps_per_epoch for i in range(len(lrs))]

    fig, ax = plt.subplots()
    try:
        ax.plot(epoch_axis, lrs, ".-", label="Learning Rate")
        ax.set_xlabel("epoch")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Scheduler")
        ax.legend()  # the label above is only shown once a legend is drawn
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
def main():
    """
    Build GeoTr with an AdamW + OneCycleLR setup and plot the resulting
    learning-rate schedule to ``./lr_scheduler.png``.

    ``total_steps`` is derived from ``epochs * steps_per_epoch`` rather than
    hard-coded. This keeps the OneCycleLR horizon equal to the number of steps
    ``plot_lr_scheduler`` simulates, which is 30 per epoch by default.
    """
    epochs = 65
    steps_per_epoch = 30  # must match plot_lr_scheduler's per-epoch step count
    max_lr = 1e-4

    model = GeoTr()
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=max_lr,
        total_steps=epochs * steps_per_epoch,
        phase_pct=0.1,
        end_learning_rate=max_lr / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=scheduler, optimizer=optimizer, epochs=epochs, save_dir="./"
    )


if __name__ == "__main__":
    main()