koyu008 commited on
Commit
2d04c0e
·
verified ·
1 Parent(s): 99969dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -6,6 +6,8 @@ from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTo
6
  from langdetect import detect
7
  from huggingface_hub import snapshot_download
8
  import os
 
 
9
 
10
  # Device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -84,34 +86,44 @@ app.add_middleware(
84
 
85
 
86
  class TextIn(BaseModel):
87
- text: str
88
 
89
 
 
90
  @app.post("/api/predict")
91
  def predict(data: TextIn):
92
- text = data.text
93
- try:
94
- lang = detect(text)
95
- except:
96
- lang = "unknown"
97
-
98
- if lang == "en":
99
- tokenizer = english_tokenizer
100
- model = english_model
101
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
102
- with torch.no_grad():
103
- outputs = model(**inputs)
104
- probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
105
- return {"language": "English", "predictions": dict(zip(english_labels, probs))}
106
-
107
- else:
108
- tokenizer = hinglish_tokenizer
109
- model = hinglish_model
110
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
111
- with torch.no_grad():
112
- outputs = model(**inputs)
113
- probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
114
- return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  @app.get("/")
 
6
  from langdetect import detect
7
  from huggingface_hub import snapshot_download
8
  import os
9
+ from typing import List
10
+
11
 
12
  # Device
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
86
 
87
 
88
  class TextIn(BaseModel):
89
+ texts: List[str]
90
 
91
 
92
+ @app.post("/api/predict")
93
  @app.post("/api/predict")
94
  def predict(data: TextIn):
95
+ results = []
96
+
97
+ for text in data.texts:
98
+ try:
99
+ lang = detect(text)
100
+ except:
101
+ lang = "unknown"
102
+
103
+ if lang == "en":
104
+ tokenizer = english_tokenizer
105
+ model = english_model
106
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
107
+ with torch.no_grad():
108
+ outputs = model(**inputs)
109
+ probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
110
+ predictions = dict(zip(english_labels, probs))
111
+ else:
112
+ tokenizer = hinglish_tokenizer
113
+ model = hinglish_model
114
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
115
+ with torch.no_grad():
116
+ outputs = model(**inputs)
117
+ probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
118
+ predictions = dict(zip(hinglish_labels, probs))
119
+
120
+ results.append({
121
+ "text": text,
122
+ "language": lang if lang in ["en", "hi"] else "unknown",
123
+ "predictions": predictions
124
+ })
125
+
126
+ return {"results": results}
127
 
128
 
129
  @app.get("/")