|
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer, DistilBertModel, DistilBertTokenizer
|
|
|
|
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# langdetect samples randomly during detection, so short or ambiguous text can
# yield a different language on each call (and thus be routed to a different
# model). Fixing the seed once, before the first detect(), makes routing
# deterministic.
DetectorFactory.seed = 0

# Fetch the fine-tuned classifier checkpoints from the Hugging Face Hub; each
# call returns a local directory that the checkpoint files are loaded from.
english_repo = snapshot_download("koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download("koyu008/HInglish_comment_classifier")

# Tokenizers for the backbones the two classifiers are built on
# (DistilBERT uncased for English, XLM-RoBERTa base for Hinglish).
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
|
|
|
|
|
|
|
class ToxicBERT(nn.Module):
    """DistilBERT encoder with a dropout + linear head producing six scores.

    The head reads the hidden state at sequence position 0 of each input and
    returns unnormalized scores (one per label); no activation is applied
    here, so callers choose how to turn them into probabilities.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        label_count = 6
        self.classifier = nn.Linear(self.bert.config.hidden_size, label_count)

    def forward(self, input_ids, attention_mask):
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state as the sequence summary.
        first_token_state = encoder_out.last_hidden_state[:, 0]
        regularized = self.dropout(first_token_state)
        return self.classifier(regularized)
|
|
|
|
|
|
|
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a bottleneck MLP head.

    ``forward`` returns two unnormalized scores per input; no activation is
    applied here.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        # Input width is 2 * hidden_size because ``pool`` concatenates the
        # mean- and max-pooled vectors.
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate mean- and max-pooled token states along the feature axis.

        ``hidden`` is reduced over dim 1 (the token axis) twice, and the two
        results are joined on dim 1, so a ``(batch, seq, dim)`` input yields a
        ``(batch, 2 * dim)`` output.

        Defined as a method rather than a lambda attribute so the module stays
        picklable (e.g. for ``torch.save(model)``).

        NOTE(review): pooling covers every position, including padding tokens.
        This is harmless for single-text batches (no padding is added) and
        must stay as-is to match how the checkpoint was trained.
        """
        return torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.pool(hidden)
        return self.classifier(self.bottleneck(pooled))
|
|
|
|
|
|
|
def _load_classifier(model, repo_dir, filename):
    """Load the state dict ``repo_dir/filename`` into *model* and switch it to eval mode.

    The checkpoint comes from a remotely downloaded repo, so it is read with
    ``weights_only=True``: ``torch.load`` then only reconstructs tensors and
    plain containers instead of running arbitrary unpickling code. This is
    sufficient here because the loaded object is passed straight to
    ``load_state_dict``.

    Returns the same *model* object for convenient assignment.
    """
    checkpoint_path = os.path.join(repo_dir, filename)
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Disable dropout for deterministic inference.
    model.eval()
    return model


english_model = _load_classifier(
    ToxicBERT().to(device), english_repo, "bert_toxic_classifier.pt"
)

hinglish_model = _load_classifier(
    HinglishToxicClassifier().to(device), hinglish_repo, "best_hinglish_model.pt"
)
|
|
|
|
|
# Label names in the index order of the English model's six outputs.
english_labels = [
    'toxic',
    'severe toxic',
    'obscene',
    'threat',
    'insult',
    'identity hate',
]

# Label names in the index order of the Hinglish model's two outputs.
hinglish_labels = ['not toxic', 'toxic']
|
|
|
|
|
app = FastAPI()

# Allow browser clients from any origin to call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# questionable. The CORS spec forbids "*" on credentialed responses, and
# Starlette may respond by echoing the caller's Origin, which effectively
# allows credentialed calls from any site (confirm against the Starlette
# version in use). The endpoints in this file read no cookies or auth
# headers, so consider listing explicit origins or setting
# allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class TextIn(BaseModel):
    """JSON request body for ``POST /api/predict``.

    Attributes:
        text: The raw comment text to classify.
    """

    text: str
|
|
|
|
|
def _predict_logits(model, tokenizer, text):
    """Encode *text* as a one-item batch and return *model*'s raw output scores.

    Only ``input_ids`` and ``attention_mask`` are forwarded, because those are
    the only inputs both classifiers' ``forward`` accepts. Any extra tokenizer
    outputs (e.g. ``token_type_ids``) are ignored rather than causing a
    ``TypeError``.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        return model(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
        )


@app.post("/api/predict")
def predict(data: TextIn):
    """Classify ``data.text`` for toxicity using a language-specific model.

    Text detected as English (``"en"``) goes to the six-label English model,
    whose labels are scored independently with a sigmoid. Everything else,
    including text whose language cannot be detected, goes to the two-class
    Hinglish model, whose scores are a softmax over ``hinglish_labels``.

    Returns:
        ``{"language": "English" | "Hinglish", "predictions": {label: probability}}``
    """
    text = data.text
    try:
        lang = detect(text)
    except LangDetectException:
        # langdetect could not identify a language (e.g. empty or symbol-only
        # text); such input is routed to the Hinglish model below.
        lang = "unknown"

    if lang == "en":
        logits = _predict_logits(english_model, english_tokenizer, text)
        # Multi-label head: each label gets its own independent probability.
        probs = torch.sigmoid(logits)[0].cpu().tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    logits = _predict_logits(hinglish_model, hinglish_tokenizer, text)
    # Two mutually exclusive classes: the probabilities sum to 1.
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
|
|
|
@app.get("/")
def root():
    """Report that the API process is up and serving requests."""
    status = {"message": "API is running"}
    return status
|
|
|
|