Commit
·
99575b1
1
Parent(s):
95b94fd
Increase threshold
Browse files- api.py +17 -5
- example_uses.md +3 -3
- inference_example.py +1 -1
- inference_lstm.py +1 -1
- knowledge_distillation.py +5 -3
- trainer.py +2 -2
- utils/convert_vihsd_gemini.py +4 -4
api.py
CHANGED
@@ -40,7 +40,7 @@ def load_model_lstm():
|
|
40 |
model = model.to(device)
|
41 |
return model, device
|
42 |
|
43 |
-
def inference(model, device, comments: str | list, threshold: float = 0.
|
44 |
if isinstance(comments, str):
|
45 |
comments = [comments]
|
46 |
elif not isinstance(comments, list):
|
@@ -73,6 +73,7 @@ def inference(model, device, comments: str | list, threshold: float = 0.5):
|
|
73 |
|
74 |
# Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
|
75 |
probs = torch.where(probs > threshold, probs, 0.0)
|
|
|
76 |
# Argmax over each group of classes_per_group
|
77 |
predictions = probs.argmax(dim=-1)
|
78 |
else:
|
@@ -95,11 +96,22 @@ def inference(model, device, comments: str | list, threshold: float = 0.5):
|
|
95 |
if __name__ == "__main__":
|
96 |
|
97 |
model, device = load_model_bert()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
comments = [
|
99 |
-
"
|
100 |
-
"
|
101 |
-
"
|
102 |
-
"
|
|
|
|
|
|
|
103 |
]
|
104 |
predictions = inference(model, device, comments)
|
105 |
print("BERT Predictions:")
|
|
|
40 |
model = model.to(device)
|
41 |
return model, device
|
42 |
|
43 |
+
def inference(model, device, comments: str | list, threshold: float = 0.55):
|
44 |
if isinstance(comments, str):
|
45 |
comments = [comments]
|
46 |
elif not isinstance(comments, list):
|
|
|
73 |
|
74 |
# Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
|
75 |
probs = torch.where(probs > threshold, probs, 0.0)
|
76 |
+
print("Probabilities: ", probs)
|
77 |
# Argmax over each group of classes_per_group
|
78 |
predictions = probs.argmax(dim=-1)
|
79 |
else:
|
|
|
96 |
if __name__ == "__main__":
|
97 |
|
98 |
model, device = load_model_bert()
|
99 |
+
'''comments = [
|
100 |
+
"Em ăn hoành thánh sáng bị khó chịu mắc ói quá bỏ ăn trưa luôn. Các thím thường hay uống gì cho đỡ vậy? Em tính làm gói gừng pha uống",
|
101 |
+
"Quan trọng là năm nay có tham gia những lễ hội có tính chất, quy mô và bối cảnh y hệt vậy không? Chứ tôi nói thật, dù ở bất cứ đâu mà tập trung đông đến mức không tiến không lùi như này được thì đều nguy hiểm. Khoan nói về giẫm đạp, chỉ riêng việc có sự cố đột xuất xảy ra thì chuyện cấp cứu nó sẽ vô cùng khó khăn và mất rất nhiều thời gian. Bởi vậy, tôi từ chối tham gia tất cả lễ hội nơi mà số người vượt tải đến mức không thể nhúc nhích như thế này.",
|
102 |
+
"Còn phải tốn hơn nữa mới được",
|
103 |
+
"Mình k có ý kích dục fen nhé :v Có sao kể vậy thôi.",
|
104 |
+
"Này là lúc trước khi gặp P hả bác? Em thắc mắc là bác có thể thẳng thừng chặn C - người bác yêu như vậy à?",
|
105 |
+
"Thì mượt hơn là đúng thôi. Mới phát triển thì không có nhiều tính năng, không có nhiều app thì chả mượt",
|
106 |
+
]'''
|
107 |
comments = [
|
108 |
+
"đúng là vozer, nhiều thằng sống ngu và ích kỷ vcl, nếu như người yêu nó cần 1 trái thận, lúc đó bản thân suy nghĩ tính toán thì ok, này chạy xe có 40km mà tính toán chi ly, mua cái váy mà mặc đi",
|
109 |
+
"Khác mẹ gì tàu khựa, bơm tiền cho đám NGO woke đi biểu tình phá lại bọn tây lông thôi. Chó chê mèo lắm lông. À mà acc Emma Roberts bị ban rồi à mày",
|
110 |
+
"đùa, cái shop thế mà cũng bảo chính hãng, vả vỡ alo nó đi. ra trung tâm thương mại, hay cửa hàng chính hãng mà mua.",
|
111 |
+
"qua thớt này của nó thì 90% là xiaolol rùi",
|
112 |
+
"thằng này chuyên đăng bài để hả hê, khóa mõm nó đi mod",
|
113 |
+
"Đm nhẫm vào đuổi con bò đỏ này nó giãy nảy cắn người kinh thật @@ Tao có hay ko liên quan lol gì mà mày có vẻ cay cú vkl nhỉ, chắc gato với tao hả ))",
|
114 |
+
"Sao thế óc chó, bị chửi cho ngu người rồi à =]] thứ ngu học chả biết mẹ gì vào sủa như đúng rồi =]]",
|
115 |
]
|
116 |
predictions = inference(model, device, comments)
|
117 |
print("BERT Predictions:")
|
example_uses.md
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
|
2 |
## Example uses:
|
3 |
|
4 |
-
- Train with
|
5 |
```
|
6 |
python ./train.py --bert_model "vinai/phobert-base-v2" --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --epochs 7 --num_classes 4 --output "./vietnamese_hate_speech_detection_phobert"
|
7 |
```
|
8 |
-
- Inference with
|
9 |
```
|
10 |
python ./inference_example.py --bert_model "vinai/phobert-base-v2" --model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --num_classes 4 --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --data_path "./datasets/test.csv" --inference_batch_limit 10
|
11 |
```
|
12 |
|
13 |
-
- Train LSTM model from
|
14 |
```
|
15 |
python ./distill_bert_to_lstm.py --bert_model "vinai/phobert-base-v2" --bert_model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --output_dir "./vietnamese_hate_speech_detection_phobert" --batch_size 32 --epochs 10 --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --num_classes 4
|
16 |
```
|
|
|
1 |
|
2 |
## Example uses:
|
3 |
|
4 |
+
- Train with PhoBERT model (train.csv is ViTHSD dataset with 4 classes each for 5 categories)
|
5 |
```
|
6 |
python ./train.py --bert_model "vinai/phobert-base-v2" --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --epochs 7 --num_classes 4 --output "./vietnamese_hate_speech_detection_phobert"
|
7 |
```
|
8 |
+
- Inference with PhoBERT model (test_data.csv is test dataset with 4 classes each for 5 categories like ViTHSD)
|
9 |
```
|
10 |
python ./inference_example.py --bert_model "vinai/phobert-base-v2" --model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --num_classes 4 --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --data_path "./datasets/test.csv" --inference_batch_limit 10
|
11 |
```
|
12 |
|
13 |
+
- Train LSTM model from PhoBERT model using distillation (train dataset should be the same as distillation training dataset)
|
14 |
```
|
15 |
python ./distill_bert_to_lstm.py --bert_model "vinai/phobert-base-v2" --bert_model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --output_dir "./vietnamese_hate_speech_detection_phobert" --batch_size 32 --epochs 10 --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --num_classes 4
|
16 |
```
|
inference_example.py
CHANGED
@@ -19,7 +19,7 @@ if __name__ == "__main__":
|
|
19 |
parser.add_argument("--class_names", type=str, nargs='+', required=False, help="List of class names for classification")
|
20 |
parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Limit for inference batch counts")
|
21 |
parser.add_argument("--print_predictions", type=bool, default=False, help="Print predictions to console")
|
22 |
-
parser.add_argument("--threshold", type=float, default=0.
|
23 |
args = parser.parse_args()
|
24 |
|
25 |
class_names = args.class_names
|
|
|
19 |
parser.add_argument("--class_names", type=str, nargs='+', required=False, help="List of class names for classification")
|
20 |
parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Limit for inference batch counts")
|
21 |
parser.add_argument("--print_predictions", type=bool, default=False, help="Print predictions to console")
|
22 |
+
parser.add_argument("--threshold", type=float, default=0.55, help="Threshold for classification")
|
23 |
args = parser.parse_args()
|
24 |
|
25 |
class_names = args.class_names
|
inference_lstm.py
CHANGED
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
|
30 |
parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension of LSTM")
|
31 |
parser.add_argument("--num_layers", type=int, default=2, help="Number of LSTM layers")
|
32 |
parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability")
|
33 |
-
parser.add_argument("--threshold", type=float, default=0.
|
34 |
args = parser.parse_args()
|
35 |
|
36 |
class_names = args.class_names
|
|
|
30 |
parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension of LSTM")
|
31 |
parser.add_argument("--num_layers", type=int, default=2, help="Number of LSTM layers")
|
32 |
parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability")
|
33 |
+
parser.add_argument("--threshold", type=float, default=0.55, help="Threshold for classification")
|
34 |
args = parser.parse_args()
|
35 |
|
36 |
class_names = args.class_names
|
knowledge_distillation.py
CHANGED
@@ -231,7 +231,7 @@ class DistillationTrainer:
|
|
231 |
logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
|
232 |
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
|
233 |
|
234 |
-
def evaluate(self, data_loader=None, phase="Validation"):
|
235 |
"""
|
236 |
Evaluate the student model
|
237 |
"""
|
@@ -284,9 +284,11 @@ class DistillationTrainer:
|
|
284 |
classes_per_group = total_classes // self.num_categories
|
285 |
# Group every classes_per_group values along dim=1
|
286 |
reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group) # shape: (batch, self., classes_per_group)
|
287 |
-
|
|
|
|
|
288 |
# Argmax over each group of classes_per_group
|
289 |
-
preds =
|
290 |
else:
|
291 |
_, preds = torch.max(student_logits, 1)
|
292 |
all_preds = np.append(all_preds, preds.cpu().numpy())
|
|
|
231 |
logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
|
232 |
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
|
233 |
|
234 |
+
def evaluate(self, data_loader=None, phase="Validation", threshold=0.55):
|
235 |
"""
|
236 |
Evaluate the student model
|
237 |
"""
|
|
|
284 |
classes_per_group = total_classes // self.num_categories
|
285 |
# Group every classes_per_group values along dim=1
|
286 |
reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group) # shape: (batch, self., classes_per_group)
|
287 |
+
probs = F.softmax(reshaped, dim=1)
|
288 |
+
# Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
|
289 |
+
probs = torch.where(probs > threshold, probs, 0.0)
|
290 |
# Argmax over each group of classes_per_group
|
291 |
+
preds = probs.argmax(dim=-1)
|
292 |
else:
|
293 |
_, preds = torch.max(student_logits, 1)
|
294 |
all_preds = np.append(all_preds, preds.cpu().numpy())
|
trainer.py
CHANGED
@@ -219,7 +219,7 @@ class Trainer:
|
|
219 |
f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
|
220 |
f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
|
221 |
|
222 |
-
def evaluate(self, data_loader, phase="Validation"):
|
223 |
"""
|
224 |
Evaluation function for both validation and test sets
|
225 |
"""
|
@@ -280,7 +280,7 @@ class Trainer:
|
|
280 |
|
281 |
# Softmax and apply threshold
|
282 |
probs = torch.softmax(reshaped, dim=1)
|
283 |
-
probs = torch.where(probs >
|
284 |
# Argmax over each group of classes_per_group
|
285 |
preds = probs.argmax(dim=-1)
|
286 |
else:
|
|
|
219 |
f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
|
220 |
f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
|
221 |
|
222 |
+
def evaluate(self, data_loader, phase="Validation", threshold=0.55):
|
223 |
"""
|
224 |
Evaluation function for both validation and test sets
|
225 |
"""
|
|
|
280 |
|
281 |
# Softmax and apply threshold
|
282 |
probs = torch.softmax(reshaped, dim=1)
|
283 |
+
probs = torch.where(probs > threshold, probs, 0.0)
|
284 |
# Argmax over each group of classes_per_group
|
285 |
preds = probs.argmax(dim=-1)
|
286 |
else:
|
utils/convert_vihsd_gemini.py
CHANGED
@@ -42,7 +42,7 @@ def classify_text(model, text):
|
|
42 |
print(f"Error classifying text: {e}")
|
43 |
return None
|
44 |
|
45 |
-
def process_file(input_file, output_file, model, rate_limit_pause=4):
|
46 |
"""Process a single CSV file to match the test.csv format"""
|
47 |
print(f"Processing {input_file}...")
|
48 |
|
@@ -53,9 +53,9 @@ def process_file(input_file, output_file, model, rate_limit_pause=4):
|
|
53 |
print(f"Error reading {input_file}: {e}")
|
54 |
return
|
55 |
|
56 |
-
# Rename column
|
57 |
-
if
|
58 |
-
df.rename(columns={
|
59 |
elif 'content' not in df.columns:
|
60 |
print(f"Error: 'content' column not found in {input_file}")
|
61 |
return
|
|
|
42 |
print(f"Error classifying text: {e}")
|
43 |
return None
|
44 |
|
45 |
+
def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="free_text"):
|
46 |
"""Process a single CSV file to match the test.csv format"""
|
47 |
print(f"Processing {input_file}...")
|
48 |
|
|
|
53 |
print(f"Error reading {input_file}: {e}")
|
54 |
return
|
55 |
|
56 |
+
# Rename column text_col to content
|
57 |
+
if text_col in df.columns:
|
58 |
+
df.rename(columns={text_col: 'content'}, inplace=True)
|
59 |
elif 'content' not in df.columns:
|
60 |
print(f"Error: 'content' column not found in {input_file}")
|
61 |
return
|