Overview
Biomedicine has been a major focus of the AI community for many years. With the recent rise of large language models (LLMs), a growing number of applications use these models to support biomedical research and healthcare.
We present MedConclusion, an LLM fine-tuned from Phi-3-medium on over 250,000 PubMed articles. MedConclusion reads a research article and generates a clear, concise conclusion for it. This specialized model has a wide range of potential applications, including:
Applications
Clinical Decision Support Systems: By integrating the model into hospital systems or pharmaceutical research institutes, healthcare providers can access summarized research findings to support evidence-based clinical decisions.
Medical Education and Training: Students can be given concise summaries that highlight the key information in research articles.
Public Health Policy Development: Policymakers can use research summaries to make informed decisions about public health strategies.
Research Recommendation Systems: Concise generated conclusions help researchers stay up to date with new findings with minimal effort.
Biomedical Search Engines: Enhanced search results with auto-generated conclusions offer users quick insights into the relevance and key takeaways of each paper.
Academic Tools and Writing Assistance: Academic platforms can use the model to create quick summaries for papers under review, improving their visibility and accessibility.
Training Dataset
MedConclusion is fine-tuned using the PubMedQA dataset, specifically on the PQA-A and PQA-U training subsets.
For a complete description of the PubMedQA dataset, please visit the original source: https://github.com/pubmedqa/pubmedqa
The instruction-formatted datasets are:
pqau_genconc.jsonl and pqaa_genconc.jsonl: training
testset_genconc.jsonl: validation
pubmedqa_testset.csv: inference
The datasets are available in the med-rcq/med-rcq-dataset repository: https://huggingface.co/datasets/med-rcq/med-rcq-dataset/tree/main
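The `.jsonl` files use the JSON Lines layout, with one JSON object per line. This card does not document the field names inside each record, so the minimal sketch below assumes only that layout. `read_jsonl` is a hypothetical helper, not part of the released code.

```python
import json
from pathlib import Path


def read_jsonl(path):
    """Hypothetical helper: return all records of a JSON Lines file as a list.

    No field names are assumed; each non-blank line is parsed as one JSON object.
    """
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Report the file and line number so malformed records are easy to find
                raise ValueError(f"{path}:{line_no}: invalid JSON ({err.msg})") from err
    return records
```

For example, `records = read_jsonl("pqau_genconc.jsonl")` after downloading the file from the dataset repository. The keys of `records[0]` then reveal the actual record schema.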
Training and Inference Parameters
Training Parameter | MedConclusion |
---|---|
Learning rate | 2e-04 |
Seed | 42 |
Scheduler | cosine |
Warmup Ratio | 0.05 |
Optimizer | AdamW |
Gradient Accumulation steps | 4 |
Train batch size per device | 8 |
Effective Batch size | 192 |
Evaluation batch size | 4 |
Cut-off length (tokens) | 1024 |
Number of GPUs | 6 |
LoRA Rank | 32 |
LoRA Alpha | 32 |
LoRA Dropout | 0.05 |
LoRA Target | All |
Number of epochs | 1 |
Max Grad Norm | 1.0 |
Training Time | 11.5 hours |
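The effective batch size in the table is the per-device batch size × gradient accumulation steps × number of GPUs, which gives 8 × 4 × 6 = 192. The sketch below reproduces that arithmetic. It also evaluates a linear-warmup, cosine-decay learning-rate schedule using the table's peak learning rate and warmup ratio. This schedule is a standalone illustration of the common "cosine with warmup" shape. It assumes that the number of warmup steps is the warmup ratio times the total steps, rounded up. The card does not state the training framework or the exact number of optimizer steps.

```python
import math

# Values taken from the training-parameter table
PEAK_LR = 2e-4
WARMUP_RATIO = 0.05
PER_DEVICE_BATCH = 8
GRAD_ACCUM_STEPS = 4
NUM_GPUS = 6


def effective_batch_size(per_device, grad_accum, num_gpus):
    """Number of training examples contributing to each optimizer step."""
    return per_device * grad_accum * num_gpus


def cosine_lr_with_warmup(step, total_steps, peak_lr=PEAK_LR, warmup_ratio=WARMUP_RATIO):
    """Learning rate at `step`: linear warmup to peak_lr, then cosine decay to 0.

    Assumption: warmup steps = ceil(total_steps * warmup_ratio).
    """
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, from 0.0 to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


batch = effective_batch_size(PER_DEVICE_BATCH, GRAD_ACCUM_STEPS, NUM_GPUS)
print(batch)  # 192
```

Consider a hypothetical run of 1,000 optimizer steps. Warmup covers the first 50 steps, and the rate peaks at 2e-4 at step 50. It falls to about 1e-4 at step 525, halfway through the decay, and reaches 0 at step 1,000.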
Inference Parameter | MedConclusion |
---|---|
Temperature | 0.01 |
Max new tokens | 250 |
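With sampling enabled, a temperature of 0.01 makes generation nearly deterministic. The logits are divided by the temperature before the softmax, so even small gaps between them become very large. The standalone sketch below illustrates the effect with made-up logits. It is not the model's actual decoding code.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing them by `temperature`."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical next-token logits for three candidate tokens
logits = [2.0, 1.5, 0.5]
print([round(p, 3) for p in softmax_with_temperature(logits, 1.0)])
print([round(p, 3) for p in softmax_with_temperature(logits, 0.01)])
```

At temperature 1.0 all three tokens keep a noticeable share of the probability. At 0.01 essentially all of the mass falls on the highest-scoring token, so sampling behaves almost like greedy decoding.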
Environment Setup
OS: Ubuntu 22.04.3
GPU: A40 or RTX A6000, CUDA 12.4
To set up the environment, run the following commands:
curl -O https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh
/bin/bash Anaconda3-2024.02-1-Linux-x86_64.sh -b -p /opt/conda
source ~/.bashrc
export PATH=/opt/conda/bin:$PATH
source /opt/conda/bin/activate
conda create -n medrcq_env python=3.11.7 -y
conda activate medrcq_env
pip install torch==2.5.1 transformers==4.48.0 pandas==2.1.4
pip install flash-attn==2.7.3
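After installation, you can confirm that the pinned versions are active by comparing them with the installed package metadata. The script below is a hypothetical convenience check, not part of the released code. It uses only the standard library (`importlib.metadata` and `platform`).

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the setup commands
EXPECTED = {
    "torch": "2.5.1",
    "transformers": "4.48.0",
    "pandas": "2.1.4",
    "flash-attn": "2.7.3",
}


def find_mismatches(expected):
    """Hypothetical helper: return {package: installed version or None} for every mismatch."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Ignore local build suffixes such as "+cu124" when comparing
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    print("Python", platform.python_version(), "(expected 3.11.7)")
    problems = find_mismatches(EXPECTED)
    for name, installed in problems.items():
        print(f"{name}: expected {EXPECTED[name]}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```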
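The inference script is shown below. It reads an input CSV file (for example, pubmedqa_testset.csv), generates a conclusion for each article, and writes the results to an output CSV file.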
from transformers import pipeline, set_seed
import torch
import pandas as pd
import argparse

MODEL_PATH = "med-rcq/MedConclusion"
set_seed(42)

# Prompt template; _TITLE_ and _CONTEXT_ are replaced with each article's title and body
SYSTEM_PROMPT = '''You are a helpful medical assistant. Write a conclusion for the following article:\n
Title: _TITLE_
_CONTEXT_
Conclusion:'''

# Text-generation pipeline; do_sample and temperature are applied to every generation call
pipe = pipeline(
    "text-generation",
    model=MODEL_PATH,
    model_kwargs={"torch_dtype": torch.bfloat16},
    trust_remote_code=True,
    do_sample=True,
    temperature=0.01,
    device="cuda",  # replace with "mps" to run on a Mac device
)
def generate_ai_conclusion(prompt):
    """
    Generate a medical conclusion for the given prompt.

    Args:
        prompt (str): The input prompt for the model.

    Returns:
        str: The generated conclusion.
    """
    messages = [{"role": "user", "content": prompt}]
    outputs = pipe(messages, max_new_tokens=250)
    # With chat-style input, generated_text holds the whole conversation; the last message is the model's reply
    assistant_response = outputs[0]["generated_text"][-1]["content"].strip()
    return assistant_response
# Process the CSV file row by row
def process_csv(input_file, output_file):
    """
    Process a CSV file by generating a medical conclusion for each row.
    Each row represents a PubMed article.

    Args:
        input_file (str): Path to the input CSV file.
        output_file (str): Path to save the processed CSV file with the generated conclusions.
    """
    # Read the input CSV file into a DataFrame
    try:
        df = pd.read_csv(input_file)
    except FileNotFoundError:
        print(f"Error: File {input_file} not found.")
        return
    except pd.errors.EmptyDataError:
        print("Error: Input file is empty.")
        return

    cols_order = ['ID', 'Question', 'Context_with_label', 'LONG_ANSWER', 'final_decision']
    # Work on an explicit copy so that adding a column later does not trigger SettingWithCopyWarning
    df = df[cols_order].copy()

    generated_conclusions = []
    # Loop over each row (one article per row) in the DataFrame
    for index, row in df.iterrows():
        article_id = row['ID']
        # The PubMed article title is phrased as a question
        title = row['Question']
        # Article body without the title or conclusion; it keeps the section labels (e.g., background, methods)
        context_string = row['Context_with_label']
        # The article's own conclusion section (reference only, not used in the prompt)
        long_answer = row['LONG_ANSWER']
        # yes, no, or maybe, as agreed by both annotators (not used in the prompt)
        final_decision = row['final_decision']

        print(f"\n########## INDEX:{index} ## QID:{article_id} ##########\n")

        # Fill in the prompt template
        formatted_prompt = SYSTEM_PROMPT.replace("_TITLE_", str(title))
        formatted_prompt = formatted_prompt.replace("_CONTEXT_", str(context_string))

        # Call the model to generate the conclusion from the constructed prompt
        generated_conclusion = generate_ai_conclusion(formatted_prompt)
        print(generated_conclusion)
        generated_conclusions.append(generated_conclusion)

    # Add the generated conclusions as a new column and save the output CSV
    df['Medconc_Generated_conclusion'] = generated_conclusions
    df.to_csv(output_file, index=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process a CSV file to generate conclusions.")
    # Input file, e.g. pubmedqa_testset.csv
    parser.add_argument("input_file", help="Path to the input CSV file")
    # Output file that will contain the generated conclusions
    parser.add_argument("output_file", help="Path to save the processed CSV file")
    args = parser.parse_args()

    # Process the file
    process_csv(args.input_file, args.output_file)
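A long inference run can fail partway through if the input is malformed, so it may help to check the CSV first. The check confirms that the file has the five columns the script selects, and it also lets you preview one formatted prompt. The sketch below does both with the standard library only. `missing_columns` and `build_prompt` are hypothetical helper names, and the template is copied from `SYSTEM_PROMPT` in the inference script.

```python
import csv
import io

# Columns selected by process_csv in the inference script
REQUIRED_COLUMNS = ["ID", "Question", "Context_with_label", "LONG_ANSWER", "final_decision"]

# Same template as SYSTEM_PROMPT in the inference script
PROMPT_TEMPLATE = '''You are a helpful medical assistant. Write a conclusion for the following article:\n
Title: _TITLE_
_CONTEXT_
Conclusion:'''


def missing_columns(csv_file):
    """Hypothetical helper: return the required columns absent from an open CSV file's header."""
    header = next(csv.reader(csv_file), [])
    return [col for col in REQUIRED_COLUMNS if col not in header]


def build_prompt(title, context):
    """Hypothetical helper: fill the template the way the script does (title first, then context)."""
    return PROMPT_TEMPLATE.replace("_TITLE_", title).replace("_CONTEXT_", context)


if __name__ == "__main__":
    # Hypothetical header row with one required column missing
    sample = io.StringIO("ID,Question,Context_with_label,LONG_ANSWER\n")
    print(missing_columns(sample))  # ['final_decision']
    print(build_prompt("Example title?", "BACKGROUND: Example background text."))
```

To check a real file, open it with `open("pubmedqa_testset.csv", newline="", encoding="utf-8")` and pass the file object to `missing_columns`. An empty list means all required columns are present.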