jesse-tong commited on
Commit
99575b1
·
1 Parent(s): 95b94fd

Increase threshold

Browse files
api.py CHANGED
@@ -40,7 +40,7 @@ def load_model_lstm():
40
  model = model.to(device)
41
  return model, device
42
 
43
- def inference(model, device, comments: str | list, threshold: float = 0.5):
44
  if isinstance(comments, str):
45
  comments = [comments]
46
  elif not isinstance(comments, list):
@@ -73,6 +73,7 @@ def inference(model, device, comments: str | list, threshold: float = 0.5):
73
 
74
  # Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
75
  probs = torch.where(probs > threshold, probs, 0.0)
 
76
  # Argmax over each group of classes_per_group
77
  predictions = probs.argmax(dim=-1)
78
  else:
@@ -95,11 +96,22 @@ def inference(model, device, comments: str | list, threshold: float = 0.5):
95
  if __name__ == "__main__":
96
 
97
  model, device = load_model_bert()
 
 
 
 
 
 
 
 
98
  comments = [
99
- "Để avata bít ngay ngu hơn chó",
100
- "Hàn Quốc chửi dân Đông Lào đây hậu quả",
101
- "Nguyễn Thuận =)) tưởng rừng vậy",
102
- "@công danh nguyen thể chế chính trị khác hẳn tư tưởng xã hội nhé. Con cờ hó china liên quan cmn gì?"
 
 
 
103
  ]
104
  predictions = inference(model, device, comments)
105
  print("BERT Predictions:")
 
40
  model = model.to(device)
41
  return model, device
42
 
43
+ def inference(model, device, comments: str | list, threshold: float = 0.55):
44
  if isinstance(comments, str):
45
  comments = [comments]
46
  elif not isinstance(comments, list):
 
73
 
74
  # Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
75
  probs = torch.where(probs > threshold, probs, 0.0)
76
+ print("Probabilities: ", probs)
77
  # Argmax over each group of classes_per_group
78
  predictions = probs.argmax(dim=-1)
79
  else:
 
96
  if __name__ == "__main__":
97
 
98
  model, device = load_model_bert()
99
+ '''comments = [
100
+ "Em ăn hoành thánh sáng bị khó chịu mắc ói quá bỏ ăn trưa luôn. Các thím thường hay uống gì cho đỡ vậy? Em tính làm gói gừng pha uống",
101
+ "Quan trọng là năm nay có tham gia những lễ hội có tính chất, quy mô và bối cảnh y hệt vậy không? Chứ tôi nói thật, dù ở bất cứ đâu mà tập trung đông đến mức không tiến không lùi như này được thì đều nguy hiểm. Khoan nói về giẫm đạp, chỉ riêng việc có sự cố đột xuất xảy ra thì chuyện cấp cứu nó sẽ vô cùng khó khăn và mất rất nhiều thời gian. Bởi vậy, tôi từ chối tham gia tất cả lễ hội nơi mà số người vượt tải đến mức không thể nhúc nhích như thế này.",
102
+ "Còn phải tốn hơn nữa mới được",
103
+ "Mình k có ý kích dục fen nhé :v Có sao kể vậy thôi.",
104
+ "Này là lúc trước khi gặp P hả bác? Em thắc mắc là bác có thể thẳng thừng chặn C - người bác yêu như vậy à?",
105
+ "Thì mượt hơn là đúng thôi. Mới phát triển thì không có nhiều tính năng, không có nhiều app thì chả mượt",
106
+ ]'''
107
  comments = [
108
+ "đúng vozer, nhiều thằng sống ngu ích kỷ vcl, nếu như người yêu nó cần 1 trái thận, lúc đó bản thân suy nghĩ tính toán thì ok, này chạy xe có 40km mà tính toán chi ly, mua cái váy mà mặc đi",
109
+ "Khác mẹ tàu khựa, bơm tiền cho đám NGO woke đi biểu tình phá lại bọn tây lông thôi. Chó chê mèo lắm lông. À mà acc Emma Roberts bị ban rồi à mày",
110
+ "đùa, cái shop thế cũng bảo chính hãng, vả vỡ alo nó đi. ra trung tâm thương mại, hay cửa hàng chính hãng mà mua.",
111
+ "qua thớt này của thì 90% xiaolol rùi",
112
+ "thằng này chuyên đăng bài để hả hê, khóa mõm nó đi mod",
113
+ "Đm nhẫm vào đuổi con bò đỏ này nó giãy nảy cắn người kinh thật @@ Tao có hay ko liên quan lol gì mà mày có vẻ cay cú vkl nhỉ, chắc gato với tao hả ))",
114
+ "Sao thế óc chó, bị chửi cho ngu người rồi à =]] thứ ngu học chả biết mẹ gì vào sủa như đúng rồi =]]",
115
  ]
116
  predictions = inference(model, device, comments)
117
  print("BERT Predictions:")
example_uses.md CHANGED
@@ -1,16 +1,16 @@
1
 
2
  ## Example uses:
3
 
4
- - Train with BERT model (train.csv is ViTHSD dataset with 4 classes each for 5 categories)
5
  ```
6
  python ./train.py --bert_model "vinai/phobert-base-v2" --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --epochs 7 --num_classes 4 --output "./vietnamese_hate_speech_detection_phobert"
7
  ```
8
- - Inference with BERT model (test_data.csv is test dataset with 4 classes each for 5 categories like ViTHSD)
9
  ```
10
  python ./inference_example.py --bert_model "vinai/phobert-base-v2" --model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --num_classes 4 --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --data_path "./datasets/test.csv" --inference_batch_limit 10
11
  ```
12
 
13
- - Train LSTM model from BERT model using distillation (train dataset should be the same as distillation training dataset)
14
  ```
15
  python ./distill_bert_to_lstm.py --bert_model "vinai/phobert-base-v2" --bert_model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --output_dir "./vietnamese_hate_speech_detection_phobert" --batch_size 32 --epochs 10 --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --num_classes 4
16
  ```
 
1
 
2
  ## Example uses:
3
 
4
+ - Train with PhoBERT model (train.csv is ViTHSD dataset with 4 classes each for 5 categories)
5
  ```
6
  python ./train.py --bert_model "vinai/phobert-base-v2" --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --epochs 7 --num_classes 4 --output "./vietnamese_hate_speech_detection_phobert"
7
  ```
8
+ - Inference with PhoBERT model (test_data.csv is test dataset with 4 classes each for 5 categories like ViTHSD)
9
  ```
10
  python ./inference_example.py --bert_model "vinai/phobert-base-v2" --model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --num_classes 4 --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --data_path "./datasets/test.csv" --inference_batch_limit 10
11
  ```
12
 
13
+ - Train LSTM model from PhoBERT model using distillation (train dataset should be the same as distillation training dataset)
14
  ```
15
  python ./distill_bert_to_lstm.py --bert_model "vinai/phobert-base-v2" --bert_model_path "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth" --output_dir "./vietnamese_hate_speech_detection_phobert" --batch_size 32 --epochs 10 --train_data_path "./datasets/train.csv" --val_data_path "./datasets/dev.csv" --test_data_path "./datasets/test.csv" --label_column "individual" "groups" "religion/creed" "race/ethnicity" "politics" --text_column "content" --num_classes 4
16
  ```
inference_example.py CHANGED
@@ -19,7 +19,7 @@ if __name__ == "__main__":
19
  parser.add_argument("--class_names", type=str, nargs='+', required=False, help="List of class names for classification")
20
  parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Limit for inference batch counts")
21
  parser.add_argument("--print_predictions", type=bool, default=False, help="Print predictions to console")
22
- parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for classification")
23
  args = parser.parse_args()
24
 
25
  class_names = args.class_names
 
19
  parser.add_argument("--class_names", type=str, nargs='+', required=False, help="List of class names for classification")
20
  parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Limit for inference batch counts")
21
  parser.add_argument("--print_predictions", type=bool, default=False, help="Print predictions to console")
22
+ parser.add_argument("--threshold", type=float, default=0.55, help="Threshold for classification")
23
  args = parser.parse_args()
24
 
25
  class_names = args.class_names
inference_lstm.py CHANGED
@@ -30,7 +30,7 @@ if __name__ == "__main__":
30
  parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension of LSTM")
31
  parser.add_argument("--num_layers", type=int, default=2, help="Number of LSTM layers")
32
  parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability")
33
- parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for classification")
34
  args = parser.parse_args()
35
 
36
  class_names = args.class_names
 
30
  parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension of LSTM")
31
  parser.add_argument("--num_layers", type=int, default=2, help="Number of LSTM layers")
32
  parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability")
33
+ parser.add_argument("--threshold", type=float, default=0.55, help="Threshold for classification")
34
  args = parser.parse_args()
35
 
36
  class_names = args.class_names
knowledge_distillation.py CHANGED
@@ -231,7 +231,7 @@ class DistillationTrainer:
231
  logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
232
  print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
233
 
234
- def evaluate(self, data_loader=None, phase="Validation"):
235
  """
236
  Evaluate the student model
237
  """
@@ -284,9 +284,11 @@ class DistillationTrainer:
284
  classes_per_group = total_classes // self.num_categories
285
  # Group every classes_per_group values along dim=1
286
  reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group) # shape: (batch, self., classes_per_group)
287
-
 
 
288
  # Argmax over each group of classes_per_group
289
- preds = reshaped.argmax(dim=-1)
290
  else:
291
  _, preds = torch.max(student_logits, 1)
292
  all_preds = np.append(all_preds, preds.cpu().numpy())
 
231
  logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
232
  print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
233
 
234
+ def evaluate(self, data_loader=None, phase="Validation", threshold=0.55):
235
  """
236
  Evaluate the student model
237
  """
 
284
  classes_per_group = total_classes // self.num_categories
285
  # Group every classes_per_group values along dim=1
286
  reshaped = student_logits.view(student_logits.size(0), -1, classes_per_group) # shape: (batch, self., classes_per_group)
287
+ probs = F.softmax(reshaped, dim=1)
288
+ # Keep only the probs that are above the threshold (to prevent false positive), else set it to 0 (NORMAL, in this case unconclusive)
289
+ probs = torch.where(probs > threshold, probs, 0.0)
290
  # Argmax over each group of classes_per_group
291
+ preds = probs.argmax(dim=-1)
292
  else:
293
  _, preds = torch.max(student_logits, 1)
294
  all_preds = np.append(all_preds, preds.cpu().numpy())
trainer.py CHANGED
@@ -219,7 +219,7 @@ class Trainer:
219
  f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
220
  f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
221
 
222
- def evaluate(self, data_loader, phase="Validation"):
223
  """
224
  Evaluation function for both validation and test sets
225
  """
@@ -280,7 +280,7 @@ class Trainer:
280
 
281
  # Softmax and apply threshold
282
  probs = torch.softmax(reshaped, dim=1)
283
- probs = torch.where(probs > 0.5, probs, 0.0)
284
  # Argmax over each group of classes_per_group
285
  preds = probs.argmax(dim=-1)
286
  else:
 
219
  f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
220
  f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
221
 
222
+ def evaluate(self, data_loader, phase="Validation", threshold=0.55):
223
  """
224
  Evaluation function for both validation and test sets
225
  """
 
280
 
281
  # Softmax and apply threshold
282
  probs = torch.softmax(reshaped, dim=1)
283
+ probs = torch.where(probs > threshold, probs, 0.0)
284
  # Argmax over each group of classes_per_group
285
  preds = probs.argmax(dim=-1)
286
  else:
utils/convert_vihsd_gemini.py CHANGED
@@ -42,7 +42,7 @@ def classify_text(model, text):
42
  print(f"Error classifying text: {e}")
43
  return None
44
 
45
- def process_file(input_file, output_file, model, rate_limit_pause=4):
46
  """Process a single CSV file to match the test.csv format"""
47
  print(f"Processing {input_file}...")
48
 
@@ -53,9 +53,9 @@ def process_file(input_file, output_file, model, rate_limit_pause=4):
53
  print(f"Error reading {input_file}: {e}")
54
  return
55
 
56
- # Rename column free_text to content
57
- if 'free_text' in df.columns:
58
- df.rename(columns={'free_text': 'content'}, inplace=True)
59
  elif 'content' not in df.columns:
60
  print(f"Error: 'content' column not found in {input_file}")
61
  return
 
42
  print(f"Error classifying text: {e}")
43
  return None
44
 
45
+ def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="free_text"):
46
  """Process a single CSV file to match the test.csv format"""
47
  print(f"Processing {input_file}...")
48
 
 
53
  print(f"Error reading {input_file}: {e}")
54
  return
55
 
56
+ # Rename column text_col to content
57
+ if text_col in df.columns:
58
+ df.rename(columns={text_col: 'content'}, inplace=True)
59
  elif 'content' not in df.columns:
60
  print(f"Error: 'content' column not found in {input_file}")
61
  return