# koyu008's picture
# Update app.py
# a3af327 verified
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, Request
from huggingface_hub import snapshot_download
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
# Compute device: CUDA when torch reports it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Local snapshots of the fine-tuned checkpoint repositories on the Hugging Face
# Hub; snapshot_download returns the local directory of each snapshot, which is
# later joined with the checkpoint file name.
english_repo = snapshot_download(repo_id="koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download(repo_id="koyu008/HInglish_comment_classifier")

# Tokenizers for the pretrained backbones the two classifiers are built on.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# English Model
class ToxicBERT(nn.Module):
    """Six-output toxicity classifier on top of a DistilBERT encoder.

    The final-layer embedding at the first token position (the [CLS] slot for
    BERT-style tokenizers) goes through dropout and one linear layer, giving
    six raw logits per input. No activation is applied here.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        width = self.bert.config.hidden_size
        self.classifier = nn.Linear(width, 6)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only position 0 of every sequence in the batch.
        first_token = encoded.last_hidden_state[:, 0]
        regularized = self.dropout(first_token)
        return self.classifier(regularized)
# Hinglish Model
class HinglishToxicClassifier(nn.Module):
    """Two-class toxicity classifier on top of an XLM-RoBERTa encoder.

    Token states from the last encoder layer are pooled by concatenating their
    mean and max over the sequence axis. The pooled vector is projected back to
    ``hidden_size`` by a ReLU/dropout bottleneck and then mapped to two raw
    logits. No activation is applied here.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # The input is the concatenated [mean; max] pooling, hence 2 * hidden_size.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooled token states along the feature axis.

        For a ``(batch, seq, dim)`` tensor such as ``last_hidden_state`` this
        returns ``(batch, 2 * dim)``.

        Defined as a method rather than a lambda attribute so the module stays
        picklable (e.g. ``torch.save(model)``). It is still callable as
        ``model.pool(hidden)``.

        NOTE(review): the attention mask is not applied, so padded positions in a
        padded batch would be included in both statistics. Kept as-is to match the
        pooling the checkpoint was presumably trained with.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# Instantiate both classifiers, restore their fine-tuned weights, and switch
# them to eval mode (disables dropout).
def _load_finetuned(model, repo_dir, filename):
    """Load the state dict stored at ``repo_dir/filename`` into *model* and put it in eval mode.

    NOTE(review): ``torch.load`` unpickles the checkpoint downloaded from the Hub.
    On torch versions that support it, consider passing ``weights_only=True``
    so that only tensors and plain containers can be deserialized.
    """
    checkpoint_path = os.path.join(repo_dir, filename)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model


english_model = _load_finetuned(ToxicBERT().to(device), english_repo, "bert_toxic_classifier.pt")
hinglish_model = _load_finetuned(HinglishToxicClassifier().to(device), hinglish_repo, "best_hinglish_model.pt")

# Output label names, in the same order as each model's logits.
english_labels = ['toxic', 'severe toxic', 'obscene', 'threat', 'insult', 'identity hate']
hinglish_labels = ['not toxic', 'toxic']
# FastAPI application, CORS policy, and request schema
app = FastAPI()
from fastapi.middleware.cors import CORSMiddleware

# NOTE(review): a wildcard origin together with allow_credentials=True means any
# site may issue credentialed cross-origin requests (Starlette echoes the request
# Origin in that case). Restrict allow_origins to the frontend domain(s) if the
# API ever relies on cookies or auth headers.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)


class TextIn(BaseModel):
    """JSON request body for the prediction endpoint."""

    # Raw comment text to classify.
    text: str
def _run_classifier(tokenizer, model, text):
    """Tokenize *text*, run *model* on it without gradients, and return its 1-D logit row."""
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        # Pass exactly the arguments both forward() signatures accept, instead of
        # splatting whatever keys the tokenizer happens to return.
        logits = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
    # One text in, one row out. Indexing (rather than squeeze()) keeps the
    # result 1-D regardless of how many outputs the model has.
    return logits[0]


@app.post("/api/predict")
def predict(data: TextIn):
    """Classify ``data.text`` for toxicity.

    The language is detected with ``langdetect``. Text detected as English is
    scored by the English model with an independent sigmoid per label. All other
    text, including text whose language cannot be detected, is scored by the
    Hinglish model with a softmax over its two classes.

    Returns:
        ``{"language": "English" | "Hinglish", "predictions": {label: probability}}``
    """
    text = data.text
    try:
        lang = detect(text)
    except LangDetectException:
        # langdetect raises this for e.g. empty or feature-less input; route such
        # text to the Hinglish model instead of failing the request.
        lang = "unknown"

    if lang == "en":
        probs = torch.sigmoid(_run_classifier(english_tokenizer, english_model, text)).cpu().tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    probs = torch.softmax(_run_classifier(hinglish_tokenizer, hinglish_model, text), dim=-1).cpu().tolist()
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
@app.get("/")
def root():
    """Return a fixed message confirming the API process is serving requests."""
    status = {"message": "API is running"}
    return status