import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm


# 1. Dataset class for loading and processing data
class FullChatDataset(Dataset):
    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"), max_length=128):
        self.datasets = []
        # Load all specified datasets, skipping any that fail to download
        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                print(f"Failed to load dataset {name}: {e}")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        # bert-base-uncased already ships a [PAD] token, so this is a no-op
        # kept only for tokenizers that lack one
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Map the global index onto the dataset it falls into
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)

        # Handle the different dataset formats
        if 'dialog' in item:  # daily_dialog-style list of turns
            dialog = item['dialog']
        elif 'messages' in item:  # datasets that store message dicts
            dialog = [msg['text'] for msg in item['messages']]
        else:  # fallback: collect every string field as a turn
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        # The last turn is the response; everything before it is the context
        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }


# 2. Model architecture
class SimpleTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True so tensors are (batch, seq_len, d_model) throughout
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        # The tokenizer marks real tokens with 1 and padding with 0;
        # TransformerEncoder expects True at the *padded* positions
        padding_mask = attention_mask == 0 if attention_mask is not None else None
        x = self.transformer(x, src_key_padding_mask=padding_mask)
        return self.fc(x)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); the pe slice broadcasts over the batch
        return x + self.pe[:x.size(1)]


# 3. Model training
def train(model, dataloader, epochs=3, lr=3e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is BERT's [PAD] id
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            inputs = batch['input_ids'].to(device)
            masks = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs, masks)
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")
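
# Note: with labels identical to input_ids, a bidirectional encoder can
# minimize the loss above by simply copying each input token. A common
# alternative is next-token prediction under a causal mask; the sketch
# below is illustrative (`causal_lm_step` is a hypothetical helper, not
# part of the original script) and assumes the batch_first model from
# section 2.
def causal_lm_step(model, criterion, input_ids):
    # Shift by one: positions 0..n-2 predict tokens 1..n-1
    src, tgt = input_ids[:, :-1], input_ids[:, 1:]
    # Upper-triangular -inf mask: position t cannot attend to t+1..n-1
    causal = nn.Transformer.generate_square_subsequent_mask(src.size(1)).to(src.device)
    x = model.pos_encoder(model.embedding(src))
    logits = model.fc(model.transformer(x, mask=causal))
    return criterion(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
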
# 4. Response generation
def chat(model, tokenizer, prompt, max_length=50, top_k=50, temperature=0.7):
    device = next(model.parameters()).device
    model.eval()

    # No padding here: the sampling loop appends tokens to the end of the
    # sequence, so trailing [PAD] tokens would corrupt generation
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True
    )['input_ids'].to(device)

    # SimpleTransformerModel is a plain nn.Module without Hugging Face's
    # generate(), so sample token by token with temperature + top-k instead
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids)            # (1, seq_len, vocab_size)
            next_logits = logits[0, -1] / temperature
            topk_logits, topk_ids = torch.topk(next_logits, top_k)
            probs = torch.softmax(topk_logits, dim=-1)
            next_id = topk_ids[torch.multinomial(probs, 1)]
            if next_id.item() == tokenizer.sep_token_id:
                break                            # treat [SEP] as end of reply
            input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# 5. Main process
if __name__ == "__main__":
    # Initialization
    dataset = FullChatDataset()
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Model creation (vocabulary size comes from the tokenizer)
    model = SimpleTransformerModel(len(dataset.tokenizer))

    # Training
    train(model, dataloader)

    # Saving
    torch.save(model.state_dict(), "chatbot_model.pt")
    dataset.tokenizer.save_pretrained("chatbot_tokenizer")

    # Interactive loop
    while True:
        user_input = input("You: ")
        if user_input.lower() in ['exit', 'quit']:
            break
        response = chat(model, dataset.tokenizer, user_input)
        print(f"Bot: {response}")
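
# Appendix: the original generate() call also used nucleus (top-p) sampling
# with top_p=0.95. Below is a hedged sketch of the equivalent filter for the
# manual loop in chat(); `top_p_filter` is a hypothetical helper, not part of
# the original script. Apply it to next_logits before torch.multinomial.
def top_p_filter(logits, top_p=0.95):
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    # Remove a token if the cumulative probability *before* it already
    # reaches top_p; this always keeps at least the most likely token
    remove = torch.cumsum(probs, dim=-1) - probs >= top_p
    sorted_logits[remove] = float('-inf')
    # Scatter the filtered logits back to their original vocabulary positions
    return torch.full_like(logits, float('-inf')).scatter(0, sorted_idx, sorted_logits)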