"""Train a character-level SimpleRNN language model on train_data.txt.

Reads hyperparameters from parameter.json, builds (sequence, next-char)
training pairs with a sliding window, trains — resuming from a previously
saved model if one exists — and saves the model and character vocabulary.
Ctrl-C stops training early and still saves the model.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from SimpleRNN import SimpleRNN
import os
import json
from tqdm import tqdm, trange
import time

# --- Data and hyperparameters ------------------------------------------------
with open("train_data.txt", encoding="utf-8") as f:  # with-block: no fd leak
    training_text = f.read()

chars = sorted(set(training_text))  # unique characters = model vocabulary
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

with open("parameter.json", encoding="utf-8") as f:
    parameters = json.load(f)

input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
# Configurable like the other hyperparameters; 1000 preserves the original
# default when parameter.json has no "epochs" key.
epochs = parameters.get("epochs", 1000)
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]

# Build (input-sequence tensor, target-char index) pairs; empty when the
# training text is shorter than sequence_length.
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append(
        (torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char])
    )

# --- Model -------------------------------------------------------------------
if os.path.exists(model_path):
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load model files from a trusted source.
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continue training...")
else:
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- Training loop (KeyboardInterrupt falls through to saving) ----------------
for epoch in range(epochs):
    try:
        total_loss = 0.0
        # Fresh hidden state each epoch; assumes batch size 1 — the shape
        # (num_layers, batch, hidden) expected by SimpleRNN. TODO confirm.
        hidden = torch.zeros(1, 1, hidden_size)
        pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
        count = 0
        for input_seq, target in pbar:
            count += 1
            optimizer.zero_grad()
            # Detach so gradients do not flow across sample boundaries
            # (truncated backpropagation through time).
            output, hidden = model(input_seq, hidden.detach())
            loss = criterion(output, torch.tensor([target]))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # set_description (not assigning .desc) makes tqdm redraw the bar.
            pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")
        pbar.close()
        time.sleep(1)  # brief pause so the finished bar stays readable
    except KeyboardInterrupt:
        break

# --- Persist model and vocabulary --------------------------------------------
# (The original's stray post-loop forward pass was removed: it did nothing
# useful and raised NameError on input_seq when train_data was empty.)
torch.save(model, model_path)
with open("vocab.json", "w", encoding="utf-8") as f:
    json.dump(chars, f)
print("Model saved.")