from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
from langdetect import detect
from huggingface_hub import snapshot_download
import os

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download model repos from HF Hub
english_repo = snapshot_download("koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download("koyu008/HInglish_comment_classifier")

# Tokenizers
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")


# English Model
class ToxicBERT(nn.Module):
    """DistilBERT encoder with a 6-way linear head on the first token's hidden state.

    forward() returns raw logits of shape (batch, 6); no activation is applied here.
    """

    def __init__(self):
        super().__init__()
        # Pretrained weights are loaded here and later overwritten by load_state_dict.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 6)

    def forward(self, input_ids, attention_mask):
        """Return (batch, 6) logits computed from the position-0 token representation."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        first_token = hidden[:, 0]
        return self.classifier(self.dropout(first_token))


# Hinglish Model
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder -> mean+max pooling -> bottleneck MLP -> 2-way linear head.

    forward() returns raw logits of shape (batch, 2).
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # Input is 2 * hidden_size because pool() concatenates mean and max vectors.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooled token states along the feature dim.

        A real method (rather than a lambda stored on the instance) keeps the
        module picklable; call sites `self.pool(hidden)` are unchanged.
        """
        # NOTE(review): pooling ignores attention_mask, so padding positions would be
        # included for padded batches; presumably fine for single-text requests.
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return (batch, 2) logits for the given token ids and mask."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)


# Instantiate and load models
english_model = ToxicBERT().to(device)
# weights_only=True: the checkpoint is downloaded from a remote Hub repo, so restrict
# unpickling to tensors/containers (sufficient for a plain state_dict).
english_model.load_state_dict(
    torch.load(
        os.path.join(english_repo, "bert_toxic_classifier.pt"),
        map_location=device,
        weights_only=True,
    )
)
english_model.eval()
hinglish_model = HinglishToxicClassifier().to(device)
# weights_only=True: the checkpoint is downloaded from a remote Hub repo, so restrict
# unpickling to tensors/containers (sufficient for a plain state_dict).
hinglish_model.load_state_dict(
    torch.load(
        os.path.join(hinglish_repo, "best_hinglish_model.pt"),
        map_location=device,
        weights_only=True,
    )
)
hinglish_model.eval()

# Labels
english_labels = ['toxic', 'severe toxic', 'obscene', 'threat', 'insult', 'identity hate']
hinglish_labels = ['not toxic', 'toxic']

# FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langdetect import DetectorFactory
from langdetect.lang_detect_exception import LangDetectException

# langdetect is randomized by default; a fixed seed makes detection deterministic,
# so the same text is always routed to the same model.
DetectorFactory.seed = 0

app = FastAPI()

# NOTE(review): per the Starlette docs, wildcard origins do not combine with
# allow_credentials=True as one might expect. This API reads no cookies or auth
# headers, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Or restrict to your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class TextIn(BaseModel):
    """Request body for /api/predict."""

    text: str


def _logits(tokenizer, model, text):
    """Tokenize a single text, run `model` without gradients, return (1, num_labels) CPU logits."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        # Pass only the two inputs forward() accepts, so an extra tokenizer
        # output (e.g. token_type_ids) cannot break the call.
        logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    return logits.cpu()


@app.post("/api/predict")
def predict(data: TextIn):
    """Score `data.text` for toxicity with a model chosen by detected language.

    Text detected as English ("en") goes to the 6-label English model, with an
    independent sigmoid per label. Everything else -- any other detected language,
    or text langdetect cannot classify -- goes to the 2-class Hinglish model, with a
    softmax over the two classes.

    Returns:
        {"language": "English" | "Hinglish", "predictions": {label: probability}}
    """
    text = data.text
    try:
        lang = detect(text)
    except LangDetectException:
        # langdetect raises this when it cannot extract features (e.g. empty input).
        lang = "unknown"

    if lang == "en":
        probs = torch.sigmoid(_logits(english_tokenizer, english_model, text))[0].tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    # NOTE(review): every non-English result is routed to the Hinglish model and
    # reported as "Hinglish"; confirm this fallback is intended.
    probs = torch.softmax(_logits(hinglish_tokenizer, hinglish_model, text), dim=1)[0].tolist()
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}


@app.get("/")
def root():
    """Return a fixed message so callers can check that the server is up."""
    return {"message": "API is running"}