from typing import List

import logging
import re

import pdfplumber
import torch
from fastapi import UploadFile
from gliner import GLiNER
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Entity(BaseModel):
    entity: str
    context: str
    start: int
    end: int


# Curated medical labels
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event",
]

# Check for GPU availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Initialize model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device)  # Move model to GPU if available


def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and the token limit.
    We use 700 tokens to leave some margin for the model's special tokens.

    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of tokens per chunk

    Returns:
        List[str]: List of text chunks
    """
    # Split into sentences (simple approach)
    sentences = re.split(r'(?<=[.!?])\s+', text)

    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Rough estimation of tokens (words + punctuation)
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))

        if current_length + sentence_tokens > max_tokens:
            if current_chunk:  # Save current chunk if it exists
                chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_length = 0

        current_chunk.append(sentence)
        current_length += sentence_tokens

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks

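# Illustrative usage of chunk_text (a sketch, not part of the original module):
# the sample text and token budget below are arbitrary.
#
#   sample = "Aspirin inhibits COX-1. It is commonly used to treat pain. " * 200
#   for i, piece in enumerate(chunk_text(sample, max_tokens=700)):
#       print(f"chunk {i}: ~{len(piece)} characters")
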

def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: List of extracted entities with their context
    """
    logger.debug(f"Starting extraction for file: {file.filename}")
    try:
        # pdfplumber can read the uploaded file object directly
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")
            # Join all pages into a single string; extract_text() returns None
            # for pages without extractable text, so fall back to ""
            pdf_text = " ".join((p.extract_text() or "") for p in pdf.pages)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")

        # Collapse whitespace runs so the chunk offsets computed below line up
        # with positions in pdf_text
        pdf_text = re.sub(r'\s+', ' ', pdf_text)

        # Split text into chunks
        text_chunks = chunk_text(pdf_text)
        logger.info(f"Split text into {len(text_chunks)} chunks")

        # Extract entities from each chunk
        all_entities = []
        base_offset = 0  # Keep track of the absolute position in the original text

        for chunk in text_chunks:
            # Extract entities using GLiNER
            chunk_entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)

            # Process entities from this chunk
            for ent in chunk_entities:
                if len(ent["text"]) <= 2:  # Skip very short entities
                    continue

                # Just store the entity and its position for now
                start_idx = chunk.find(ent["text"])
                if start_idx != -1:
                    all_entities.append(Entity(
                        entity=ent["text"],
                        context="",  # Will be filled later
                        start=base_offset + start_idx,
                        end=base_offset + start_idx + len(ent["text"]),
                    ))

            base_offset += len(chunk) + 1  # +1 for the space between chunks

        # Now get context for all entities using the complete original text
        final_entities = []
        for ent in all_entities:
            # Get surrounding context from the complete text
            context_start = max(0, ent.start - 50)
            context_end = min(len(pdf_text), ent.end + 50)
            context = pdf_text[context_start:context_end]

            final_entities.append(Entity(
                entity=ent.entity,
                context=context,
                start=ent.start,
                end=ent.end,
            ))

        logger.info(f"Returning {len(final_entities)} processed entities")
        return final_entities

    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise
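

# ---------------------------------------------------------------------------
# Usage sketch (assumption): this module does not define an HTTP API itself.
# A minimal FastAPI wiring could look like the commented example below; the
# app name and route path are hypothetical, not part of the original module.
#
#   from fastapi import FastAPI, File
#
#   app = FastAPI()
#
#   @app.post("/extract-entities", response_model=List[Entity])
#   async def extract_entities(file: UploadFile = File(...)) -> List[Entity]:
#       return extract_entities_from_pdf(file)
# ---------------------------------------------------------------------------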