| | import os |
| | import torch |
| | from logger import log_data, init_logger, log_img |
| | import torch.nn as nn |
| | from tqdm import tqdm, trange |
| | from torch.profiler import profile, record_function, ProfilerActivity |
| | import gc |
| | import numpy as np |
| | from eval import evaluate_topk |
| | from dataset import dataset |
| | from Levenshtein import ratio |
| | from enum import Enum |
| | import signal |
| | import sys |
| |
|
| | device = "mps" if torch.backends.mps.is_available() else "cpu" |
| |
|
| |
|
| | from collections import defaultdict |
| |
|
| |
|
| | class ValueTracker: |
| | def __init__(self): |
| | self.data = {} |
| |
|
| | def add(self, label, value): |
| | if label not in self.data: |
| | self.data[label] = [] |
| | self.data[label].append(value) |
| |
|
| | def average(self, label): |
| | values = self.data[label] |
| | if values: |
| | return sum(values) / len(values) |
| | else: |
| | return 0.0 |
| |
|
| | def reset(self, label=None): |
| | if label is not None: |
| | if label in self.data: |
| | self.data[label] = [] |
| | else: |
| | self.data = {} |
| |
|
| | def get_values(self, label): |
| | return self.data[label] |
| |
|
| | def summary(self): |
| | for label in self.data: |
| | avg = self.average(label) |
| | print(f"{label} - Average: {avg:.4f}") |
| |
|
| |
|
| | class TrainingManager: |
| | def __init__( |
| | self, |
| | net: nn.Module, |
| | dir: str, |
| | dataloader, |
| | device=device, |
| | trainstep_checkin_interval=100, |
| | epochs=100, |
| | val_dataloader=None, |
| | ): |
| |
|
| | learning_rate = 0.001 |
| |
|
| | self.clip = 1.0 |
| |
|
| | self.trainstep_checkin_interval = trainstep_checkin_interval |
| | self.epochs = epochs |
| |
|
| | self.dataloader = dataloader |
| | self.val_dataloader = val_dataloader |
| |
|
| | self.net = net |
| | self.net.to(device) |
| | self.device = device |
| |
|
| | self.dir = dir |
| |
|
| | self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1) |
| | self.optimizer = torch.optim.AdamW( |
| | self.net.parameters(), lr=learning_rate |
| | ) |
| |
|
| | |
| | |
| | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( |
| | optimizer=self.optimizer, factor=0.9, patience=10 |
| | ) |
| |
|
| | self.tracker = ValueTracker() |
| |
|
| | self.resume_epoch, self.resume_step = self.get_resume() |
| | if self.resume_epoch >= self.epochs - 1: |
| | pass |
| | elif self.resume_epoch != 0 or self.resume_step != 0: |
| | self.resume() |
| | else: |
| | if os.path.exists(self.dir) and any( |
| | os.path.isfile(os.path.join(self.dir, item)) |
| | for item in os.listdir(self.dir) |
| | ): |
| | raise ValueError(f"The directory '{self.dir}' contains files!") |
| |
|
| | os.makedirs(self.dir, exist_ok=True) |
| | os.makedirs(os.path.join(self.dir, "ckpt"), exist_ok=True) |
| |
|
| | print(f"{self.get_param_count()} parameters.") |
| | |
| | |
| | signal.signal(signal.SIGINT, self._signal_handler) |
| | self._interrupted = False |
| |
|
| | def _signal_handler(self, signum, frame): |
| | """Handle keyboard interrupt gracefully""" |
| | print("\nKeyboard interrupt received. Saving checkpoint...") |
| | self._interrupted = True |
| |
|
| | def _save_on_interrupt(self, epoch, step): |
| | """Save checkpoint and resume info on interrupt""" |
| | try: |
| | self._save("latest.pt") |
| | self.write_resume(epoch, step) |
| | print(f"Checkpoint saved at epoch {epoch}, step {step}") |
| | except Exception as e: |
| | print(f"Failed to save checkpoint: {e}") |
| | finally: |
| | print("Exiting...") |
| | sys.exit(0) |
| |
|
| | def hasnan(self): |
| | for _, param in self.net.named_parameters(): |
| | if torch.isnan(param).any(): |
| | return True |
| | for _, param in self.net.named_parameters(): |
| | if param.grad is not None and torch.isnan(param.grad).any(): |
| | return True |
| |
|
| | return False |
| |
|
| | def _save(self, name="latest.pt"): |
| | with open(os.path.join(self.dir, "ckpt", name), "wb+") as f: |
| | torch.save(self.net.state_dict(), f) |
| |
|
| | def _load(self, name="latest.pt"): |
| | self.net.load_state_dict( |
| | torch.load(os.path.join(self.dir, "ckpt", name), weights_only=True) |
| | ) |
| |
|
| | def write_resume(self, epoch, step=0): |
| | with open(os.path.join(self.dir, "ckpt", "resume.txt"), "w+") as f: |
| | f.write(f"{epoch},{step}") |
| |
|
| | def get_resume(self): |
| | try: |
| | with open(os.path.join(self.dir, "ckpt", "resume.txt"), "r") as f: |
| | content = f.read().strip() |
| | if ',' in content: |
| | epoch, step = content.split(',') |
| | return int(epoch), int(step) |
| | else: |
| | |
| | return int(content), 0 |
| | except (FileNotFoundError, ValueError): |
| | return 0, 0 |
| |
|
| | def write_best_val_loss(self, loss): |
| | with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "w+") as f: |
| | f.write(f"{loss:.6f}") |
| |
|
| | def get_best_val_loss(self): |
| | try: |
| | with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "r") as f: |
| | return float(f.read()) |
| | except (FileNotFoundError, ValueError): |
| | return float("inf") |
| |
|
| | def resume(self): |
| | self._load("latest.pt") |
| |
|
| | def save(self, loss): |
| | self._save("latest.pt") |
| |
|
| | best_val_loss = self.get_best_val_loss() |
| | if loss < best_val_loss: |
| | best_val_loss = loss |
| | self._save("best.pt") |
| | self.write_best_val_loss(best_val_loss) |
| |
|
| | |
| |
|
| | def on_trainloop_checkin(self, epoch, step, dataloader_len): |
| | if self.hasnan(): |
| | |
| | print("RESUMING") |
| | self.resume() |
| |
|
| | self._save("latest.pt") |
| | self.write_resume(epoch, step + 1) |
| |
|
| | log_data( |
| | {"Loss/Trainstep": self.tracker.average("Loss/trainstep")}, |
| | epoch * dataloader_len + step, |
| | ) |
| | log_data( |
| | {"Acc/Trainstep": self.tracker.average("Acc/trainstep")}, |
| | epoch * dataloader_len + step, |
| | ) |
| | log_data( |
| | {"TopKAcc/Trainstep": self.tracker.average("TopKAcc/trainstep")}, |
| | epoch * dataloader_len + step, |
| | ) |
| |
|
| | self.tracker.reset("Loss/trainstep") |
| | self.tracker.reset("Acc/trainstep") |
| | self.tracker.reset("TopKAcc/trainstep") |
| |
|
| | def on_epoch_checkin(self, epoch): |
| | if self.hasnan(): |
| | |
| | self.resume() |
| |
|
| | val_loss = float("inf") |
| | try: |
| | val_loss = self.tracker.average("Loss/val/epoch") |
| | except KeyError: |
| | pass |
| |
|
| | self.save( |
| | val_loss if val_loss < float("inf") else self.tracker.average("Loss/epoch") |
| | ) |
| |
|
| | log_data( |
| | { |
| | "Loss/Epoch": self.tracker.average("Loss/epoch"), |
| | "Loss/Val/Epoch": val_loss, |
| | "Perplexity/Val/Epoch": float(np.exp(val_loss)), |
| | "TopKAcc/Epoch": self.tracker.average("TopKAcc/epoch"), |
| | }, |
| | epoch, |
| | ) |
| |
|
| | self.tracker.reset("Acc/epoch") |
| | self.tracker.reset("Loss/epoch") |
| | self.tracker.reset("Loss/val/epoch") |
| | self.tracker.reset("TopKAcc/epoch") |
| | self.tracker.reset("Perplexity/val/epoch") |
| |
|
| | self.write_resume(epoch + 1, 0) |
| |
|
| | def eval_model(self, data, compute_metrics=True): |
| | if type(data) == tuple or type(data) == list: |
| | data = tuple(d.to(self.device) for d in data) |
| | batch, attn_mask = data |
| | else: |
| | data = data.to(self.device) |
| | batch = data |
| | attn_mask = None |
| |
|
| | del attn_mask |
| |
|
| | labels = batch[:, 1:].contiguous() |
| | batch = batch[:, :-1].contiguous() |
| |
|
| | |
| | results = self.net(batch, transpose=True) |
| | results = results.transpose(0, 1) |
| |
|
| | |
| | loss = self.criterion(results.reshape(-1, results.size(-1)), labels.reshape(-1)) |
| |
|
| | if not compute_metrics: |
| | return loss, None, None |
| | |
| | |
| | preds = results.reshape(-1, results.size(-1)).argmax(dim=1) |
| | labels_flat = labels.reshape(-1) |
| | acc = (preds == labels_flat).float().mean() |
| |
|
| | |
| | top_k = 5 |
| | top_k_preds = results.reshape(-1, results.size(-1)).topk(top_k, dim=1).indices |
| | top_k_acc = (top_k_preds == labels_flat.unsqueeze(1)).any(dim=1).float().mean().item() |
| |
|
| | return loss, acc, top_k_acc |
| |
|
| | def run_generation(self, data): |
| | batch, attn_mask = data |
| | start_sequence = batch[:, :-1].contiguous()[0][:100].unsqueeze(0) |
| | result = evaluate_topk( |
| | self.net, start_sequence, amt=100, k=10, temperature=0.8, device=device |
| | ) |
| |
|
| | result = dataset.manager.decode(result[0]) |
| | batch_str = dataset.manager.decode(start_sequence[0]) |
| |
|
| | result = f"<data>{batch_str}</data>{result[len(batch_str):]}" |
| | |
| |
|
| | with open(os.path.join(self.dir, "ckpt", "generated.txt"), "a+") as f: |
| | f.write(f"K=10,T=0.8: {result}\n") |
| |
|
| | def epoch_gen(self, loader): |
| | if loader is not None: |
| | for data in loader: |
| | self.run_generation(data) |
| | break |
| |
|
| | def trainstep(self, data): |
| | self.optimizer.zero_grad() |
| |
|
| | loss, acc, topk_acc = self.eval_model(data) |
| |
|
| | self.tracker.add("Loss/trainstep", loss.item()) |
| | self.tracker.add("Loss/epoch", loss.item()) |
| |
|
| | self.tracker.add("Acc/trainstep", acc.item()) |
| | self.tracker.add("TopKAcc/trainstep", topk_acc) |
| | self.tracker.add("TopKAcc/epoch", topk_acc) |
| |
|
| | loss.backward() |
| | self.optimizer.step() |
| |
|
| | return loss.detach(), acc.detach() |
| |
|
| | @torch.no_grad() |
| | def valstep(self, data): |
| | loss, acc, topk_acc = self.eval_model(data) |
| |
|
| | self.tracker.add("Loss/valstep", loss.item()) |
| | self.tracker.add("Loss/val/epoch", loss.item()) |
| |
|
| | self.tracker.add("Perplexity/val/epoch", float(np.exp(loss.item()))) |
| |
|
| | self.tracker.add("TopKAcc/valstep", topk_acc) |
| | self.tracker.add("TopKAcc/val/epoch", topk_acc) |
| |
|
| | return loss.detach(), acc.detach() |
| |
|
| | def val_loop(self, val_loader): |
| | if val_loader is not None: |
| | for step, data in enumerate( |
| | test_tqdm := tqdm( |
| | val_loader, leave=False, dynamic_ncols=True, desc=f"valloop" |
| | ) |
| | ): |
| | self.valstep(data) |
| | avg_val_loss = self.tracker.average("Loss/val/epoch") |
| | test_tqdm.set_postfix({"Val Loss": f"{avg_val_loss:.3f}"}) |
| |
|
| | def train_loop(self, dataloader, epoch): |
| | start_step = self.resume_step if epoch == self.resume_epoch else 0 |
| | |
| | for step, data in enumerate( |
| | train_tqdm := tqdm( |
| | dataloader, leave=False, dynamic_ncols=True, desc=f"trainloop" |
| | ) |
| | ): |
| | |
| | if self._interrupted: |
| | self._save_on_interrupt(epoch, step) |
| | raise KeyboardInterrupt("Training interrupted by user") |
| | |
| | |
| | if step < start_step: |
| | continue |
| | |
| | self.trainstep(data) |
| |
|
| | avg_train_loss = self.tracker.average("Loss/trainstep") |
| | train_tqdm.set_postfix({"Train Loss": f"{avg_train_loss:.3f}"}) |
| |
|
| | if ( |
| | step % self.trainstep_checkin_interval |
| | == self.trainstep_checkin_interval - 1 |
| | ): |
| | |
| | self.on_trainloop_checkin(epoch, step, len(dataloader)) |
| | |
| |
|
| | def epoch(self, epoch: int, dataloader, val_loader=None): |
| | if self._interrupted: |
| | return |
| | |
| | self.net.train() |
| | self.train_loop(dataloader, epoch) |
| | |
| | if self._interrupted: |
| | return |
| | |
| | tqdm.write(self.get_memory_stats(self.net, dataloader.dataset, sep=" / ")) |
| | self.net.eval() |
| | self.val_loop(val_loader) |
| |
|
| | if self._interrupted: |
| | return |
| | |
| | self.epoch_gen(val_loader) |
| | self.on_epoch_checkin(epoch) |
| |
|
| | def train(self, epochs=None, dataloader=None): |
| |
|
| | if epochs is not None: |
| | self.epochs = epochs |
| |
|
| | if dataloader is not None: |
| | self.dataloader = dataloader |
| |
|
| | try: |
| | for e in trange( |
| | self.resume_epoch, self.epochs, dynamic_ncols=True, unit_scale=True, unit_divisor=60 |
| | ): |
| | if self._interrupted: |
| | break |
| | |
| | self.epoch(e, self.dataloader, self.val_dataloader) |
| |
|
| | except KeyboardInterrupt: |
| | print("\nTraining interrupted. Checkpoint saved.") |
| | finally: |
| | print("Training session ended.") |
| | gc.collect() |
| | os.system( |
| | """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| | ) |
| |
|
| | @staticmethod |
| | def get_curriculum_enum(): |
| | return Enum( |
| | "Curriculum", |
| | [ |
| | ("NOOP", 1), |
| | ("CURRICULUM", 2), |
| | ("ANTICURRICULUM", 3), |
| | ("SEQUENTIAL", 4), |
| | ("HYBRID", 5), |
| | ], |
| | ) |
| |
|
| | def train_curriculum( |
| | self, epochs=None, dataloader=None, curriculum_type=None, loss_based=False |
| | ): |
| |
|
| | print(f"Training curriculum: {curriculum_type} loss_based: {loss_based}") |
| |
|
| | Curriculum = self.get_curriculum_enum() |
| |
|
| | if curriculum_type is None: |
| | curriculum_type = Curriculum.NOOP |
| |
|
| | if epochs is not None: |
| | self.epochs = epochs |
| |
|
| | if dataloader is not None: |
| | self.dataloader = dataloader |
| |
|
| | sorted_indices = sorted( |
| | range(len(self.dataloader.dataset)), |
| | key=lambda i: self.dataloader.dataset[i][1], |
| | reverse=(curriculum_type.value == Curriculum.ANTICURRICULUM.value), |
| | ) |
| |
|
| | |
| | standard_schedule = [ |
| | min(1.0, ((i + 2) - (i % 2)) / self.epochs) for i in range(self.epochs) |
| | ] |
| | hybrid_schedule = [ |
| | min(1.0, (i + 2) / self.epochs) for i in range(self.epochs) |
| | ] |
| | step_size = 1 / (self.epochs / 2) |
| |
|
| | try: |
| | for e in trange( |
| | self.resume_epoch, self.epochs, dynamic_ncols=True, unit_scale=True, unit_divisor=60 |
| | ): |
| |
|
| | if loss_based: |
| | sorted_indices = self.get_loss_based_indices( |
| | self.dataloader, |
| | anti=(curriculum_type.value == Curriculum.ANTICURRICULUM.value), |
| | ) |
| |
|
| | subset_indices = None |
| | if curriculum_type.value == Curriculum.NOOP.value: |
| | print("No curriculum") |
| | subset_indices = sorted_indices |
| | elif curriculum_type.value == Curriculum.SEQUENTIAL.value: |
| | print("Sequential curriculum") |
| | subset_indices = sorted_indices[ |
| | int( |
| | max(len(sorted_indices) * (standard_schedule[e] - step_size), 0) |
| | ) : int(len(sorted_indices) * standard_schedule[e]) |
| | ] |
| | elif curriculum_type.value == Curriculum.HYBRID.value: |
| | print("Hybrid curriculum") |
| | subset_indices = sorted_indices[ |
| | int( |
| | max(len(sorted_indices) * (hybrid_schedule[e] - step_size), 0) |
| | ) : int(len(sorted_indices) * hybrid_schedule[e]) |
| | ] |
| | elif curriculum_type.value == Curriculum.CURRICULUM.value: |
| | print("Curriculum") |
| | subset_indices = sorted_indices[ |
| | : int(len(sorted_indices) * standard_schedule[e]) |
| | ] |
| | elif curriculum_type.value == Curriculum.ANTICURRICULUM.value: |
| | print("Anti curriculum") |
| | subset_indices = sorted_indices[ |
| | : int(len(sorted_indices) * standard_schedule[e]) |
| | ] |
| | else: |
| | raise ValueError(f"Unknown curriculum type: {curriculum_type}") |
| |
|
| | subset = torch.utils.data.Subset(self.dataloader.dataset, subset_indices) |
| | cur_dataloader = torch.utils.data.DataLoader( |
| | subset, batch_size=self.dataloader.batch_size, shuffle=True |
| | ) |
| |
|
| | self.epoch(e, cur_dataloader, self.val_dataloader) |
| |
|
| | except KeyboardInterrupt: |
| | print("\nCurriculum training interrupted. Checkpoint saved.") |
| | finally: |
| | print("Curriculum training session ended.") |
| | gc.collect() |
| | os.system( |
| | """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| | ) |
| |
|
| | print("All done!") |
| | gc.collect() |
| | os.system( |
| | """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| | ) |
| |
|
| | def get_loss_based_indices(self, dataloader, anti=False): |
| | losses = [] |
| | |
| | temp_dataloader = torch.utils.data.DataLoader( |
| | dataloader.dataset, |
| | batch_size=dataloader.batch_size, |
| | shuffle=False, |
| | num_workers=( |
| | dataloader.num_workers if hasattr(dataloader, "num_workers") else 0 |
| | ), |
| | ) |
| |
|
| | with torch.no_grad(): |
| | for batch, _ in tqdm( |
| | temp_dataloader, |
| | dynamic_ncols=True, |
| | leave=False, |
| | desc="Loss-based sorting", |
| | ): |
| | loss, _, _ = self.eval_model(batch, compute_metrics=False) |
| | |
| | |
| | if isinstance(loss, torch.Tensor) and loss.dim() == 0: |
| | losses.extend([loss.item()] * batch.size(0)) |
| | else: |
| | |
| | losses.extend(loss.detach().cpu().tolist()) |
| |
|
| | sorted_indices = sorted( |
| | range(len(dataloader.dataset)), key=lambda i: losses[i], reverse=anti |
| | ) |
| | return sorted_indices |
| |
|
| | def nan_debug(self): |
| | torch.autograd.set_detect_anomaly(True) |
| |
|
| | def forward_hook(module, input, output): |
| | if isinstance(output, tuple): |
| | return |
| | if torch.isnan(output).any() or torch.isinf(output).any(): |
| | print(f"NaNs/Infs detected in {module}") |
| |
|
| | for module in self.net.modules(): |
| | module.register_forward_hook(forward_hook) |
| | self.val_loop(self.val_dataloader) |
| |
|
| | def get_param_count(self): |
| | return sum(p.numel() for p in self.net.parameters()) |
| |
|
| | def profile_trainstep(self): |
| |
|
| | self.net.train() |
| | data = next(iter(self.dataloader)) |
| |
|
| | |
| | with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: |
| | with record_function("train_step"): |
| | self.trainstep(data) |
| |
|
| | print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) |
| |
|
| | @staticmethod |
| | def get_memory_stats(net, trainset, sep="\n"): |
| | result = "" |
| | import datetime |
| | import time |
| | result += f"Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + sep |
| | import psutil |
| | if torch.backends.mps.is_available(): |
| | result += f"MPS: {torch.mps.current_allocated_memory()/1e9:.2f} GB" + sep |
| | result += f"RAM: {psutil.virtual_memory().percent}% used" + sep |
| | |
| | |
| | chunks = getattr(trainset, 'chunks', getattr(trainset.dataset, 'chunks', None)) |
| | |
| | if chunks is not None: |
| | result += f"data: {sum(p.numel() * p.element_size() for p in [chunks]) / 1e9:.2f} GB" + sep |
| | |
| | |
| | model_size = sum(p.numel() * p.element_size() for p in net.parameters()) / 1e9 |
| | result += f"Params: {model_size:.2f} GB" + sep |
| | |
| | |
| | optimizer_size = model_size * 2 |
| | result += f"Optim (est): {optimizer_size:.2f} GB" + sep |
| |
|
| | return result |
| |
|