# tiny_llm / train.py
# xcx0902's picture
# Upload folder using huggingface_hub
# 493e08a verified
# raw
# history blame
# 2.27 kB
import torch
import torch.nn as nn
import torch.optim as optim
from SimpleRNN import SimpleRNN
import os
import json
from tqdm import tqdm, trange
import time
# ---------------------------------------------------------------------------
# Corpus, vocabulary and hyper-parameters
# ---------------------------------------------------------------------------
# Read the whole training corpus. The context manager closes the file handle
# as soon as the text is in memory instead of leaving it to the garbage
# collector.
with open("train_data.txt", encoding="utf-8") as corpus_file:
    training_text = corpus_file.read()

# Character-level vocabulary: every distinct character of the corpus, sorted
# so that the index assignment is deterministic for a given corpus.
chars = sorted(set(training_text))  # Unique characters
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# Hyper-parameters live in parameter.json; the keys read below must exist.
with open("parameter.json") as parameter_file:
    parameters = json.load(parameter_file)

# Input and output widths both equal the vocabulary size.
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
epochs = 1000  # upper bound on epochs; not read from parameter.json
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]
# ---------------------------------------------------------------------------
# Training samples
# ---------------------------------------------------------------------------
# Encode the corpus to vocabulary indices once, then slice windows out of the
# encoded list; this avoids re-encoding each character for every window it
# appears in.
encoded_text = [char_to_idx[ch] for ch in training_text]

# Each sample is (window, next_char_index): a tensor holding the indices of
# `sequence_length` consecutive characters, paired with the plain int index
# of the character that immediately follows the window.
train_data = [
    (
        torch.tensor(encoded_text[start : start + sequence_length]),
        encoded_text[start + sequence_length],
    )
    for start in range(len(encoded_text) - sequence_length)
]
# ---------------------------------------------------------------------------
# Model, loss and optimiser
# ---------------------------------------------------------------------------
# Start from scratch when no checkpoint exists at model_path, otherwise
# resume from the saved model object.
# NOTE(review): torch.load(..., weights_only=False) unpickles arbitrary
# Python objects, so model_path must only ever point at a trusted file.
if not os.path.exists(model_path):
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)
else:
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continue training...")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# Ctrl-C stops training early: the KeyboardInterrupt is absorbed here so that
# execution continues with the statements after the loop instead of ending
# in a traceback.
try:
    for epoch in range(epochs):
        total_loss = 0
        # Every epoch starts from a zero hidden state.
        hidden = torch.zeros(1, 1, hidden_size)
        # The context manager closes the progress bar even when the epoch is
        # interrupted, so the terminal line is always finalised.
        with tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A") as pbar:
            for count, (input_seq, target) in enumerate(pbar, start=1):
                optimizer.zero_grad()
                # Detach the carried-over hidden state so back-propagation
                # stops at the current sample.
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                # Running mean loss of the epoch so far; refresh=False leaves
                # redrawing to tqdm's normal update schedule.
                pbar.set_description(
                    f"Epoch={epoch}, Loss={total_loss / count:.12f}",
                    refresh=False,
                )
        time.sleep(1)
except KeyboardInterrupt:
    pass
# ---------------------------------------------------------------------------
# Persist the results
# ---------------------------------------------------------------------------
# Save straight away, with no extra forward pass in between, so the model is
# written even when training stopped before a single sample was processed.
# The whole model object is pickled (not just a state_dict), matching the
# torch.load(model_path, weights_only=False) call used to resume training.
torch.save(model, model_path)

# Store the vocabulary as the sorted character list; a character's position
# in the list is its index in char_to_idx.
with open("vocab.json", "w", encoding="utf-8") as vocab_file:
    json.dump(chars, vocab_file)

print("Model saved.")