# tiny_llm / train.py
# NOTE: Hugging Face upload metadata (uploaded via huggingface_hub by
# xcx0902, commit b9d1833 verified, 2.96 kB) converted to comments so the
# file is valid Python.
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
from tqdm import tqdm, trange
import time
# Generate simple training data: read the raw corpus and build the
# character-level vocabulary and index maps.
with open("train_data.txt", encoding="utf-8") as corpus_file:
    training_text = corpus_file.read()
chars = sorted(set(training_text))  # unique characters = vocabulary
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# Model parameters (hyperparameters loaded from parameter.json)
with open("parameter.json", encoding="utf-8") as param_file:
    parameters = json.load(param_file)
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
# epochs is now overridable from parameter.json like every other
# hyperparameter; the default 1000 preserves the previous behavior.
epochs = parameters.get("epochs", 1000)
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]

# Create training data (input-output pairs): every window of
# sequence_length characters maps to the character that follows it.
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append(
        (torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char])
    )
# Define the simple RNN model
class SimpleRNN(nn.Module):
    """Character-level model: one-hot input -> single RNN layer -> linear
    projection producing logits over the vocabulary for the next character."""

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run one character sequence through the network.

        Args:
            x: 1-D LongTensor of character indices (length = sequence length).
            hidden: hidden-state tensor of shape (1, 1, hidden_size).

        Returns:
            Tuple of (logits of shape (1, output_size) for the next
            character, updated hidden state).
        """
        # Bug fix: the original read the module-level global `input_size`,
        # so the class only worked inside this script. nn.RNN exposes its
        # own input_size, which also works for previously pickled
        # checkpoints that lack any new attribute.
        x = torch.nn.functional.one_hot(x, num_classes=self.rnn.input_size).float()
        out, hidden = self.rnn(x.unsqueeze(0), hidden)
        out = self.fc(out[:, -1, :])  # take last time step's output only
        return out, hidden
# Resume from an existing checkpoint if one is present, otherwise build a
# fresh model from the hyperparameters above.
if os.path.exists(model_path):
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # only acceptable here because the checkpoint is produced locally by
    # this script. Do not load untrusted checkpoints this way.
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continue training...")
else:
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)
# Cross-entropy over next-character logits; Adam with the configured LR.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop: one optimizer step per (sequence, next-char) pair.
# Ctrl-C (KeyboardInterrupt) ends training early and falls through to the
# save step below.
for epoch in range(epochs):
    try:
        total_loss = 0
        # Fresh hidden state each epoch; carried (detached) across the
        # consecutive windows within the epoch.
        hidden = torch.zeros(1, 1, hidden_size)
        pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
        count = 0
        for input_seq, target in pbar:
            count += 1
            optimizer.zero_grad()
            # detach() truncates backprop so gradients do not flow across
            # window boundaries.
            output, hidden = model(input_seq, hidden.detach())
            loss = criterion(output, torch.tensor([target]))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.desc = f"Epoch={epoch}, Loss={total_loss / count:.12f}"
        pbar.close()
        time.sleep(1)  # brief pause so the finished bar stays readable
    except KeyboardInterrupt:
        break
# Bug fix: removed a leftover post-loop forward pass whose result was never
# used and which raised NameError (`input_seq` undefined) if training was
# interrupted before the first batch.
# Persist the trained model (full pickle via torch.save) plus the
# vocabulary list needed to encode/decode text at inference time.
torch.save(model, model_path)
with open("vocab.json", "w") as vocab_file:
    json.dump(chars, vocab_file)
print("Model saved.")