import logging
from typing import List

import pdfplumber
from fastapi import UploadFile
from gliner import GLiNER
from pydantic import BaseModel

# Set up module-level logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class Entity(BaseModel):
    """A single extracted entity plus a snippet of surrounding document text."""

    entity: str   # the matched entity text
    context: str  # text surrounding the entity (up to 50 chars each side)
    start: int    # entity start offset, relative to `context`
    end: int      # entity end offset, relative to `context`


# Curated medical labels passed to GLiNER as the open-vocabulary label set
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event",
]

# Initialize model once at import time (loading the weights is expensive)
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")


def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: List of extracted entities with their context.
            Entities of 2 characters or fewer, and entities whose text
            cannot be located in the extracted document text, are skipped.

    Raises:
        Exception: any error from PDF parsing or model inference is
            logged with traceback and re-raised to the caller.
    """
    logger.debug("Starting extraction for file: %s", file.filename)
    try:
        with pdfplumber.open(file.file) as pdf:
            logger.debug("Successfully opened PDF with %d pages", len(pdf.pages))
            # extract_text() returns None for pages with no extractable text
            # (e.g. scanned images); guard with `or ""` so join() can't raise.
            pdf_text = " ".join(p.extract_text() or "" for p in pdf.pages)
            logger.debug("Extracted text length: %d characters", len(pdf_text))

        # Extract entities using GLiNER (PDF handle already closed above)
        logger.debug("Starting GLiNER entity extraction")
        entities = gliner_model.predict_entities(
            pdf_text, MEDICAL_LABELS, threshold=0.7
        )
        logger.debug("Found %d entities", len(entities))

        # Convert raw GLiNER dicts into our Entity model format
        result = []
        for ent in entities:
            if len(ent["text"]) <= 2:  # Skip very short entities
                continue

            # Prefer the character offsets GLiNER reports for this match;
            # str.find() would mislocate repeated entities (it always
            # returns the first occurrence). Fall back to find() if the
            # reported offset doesn't line up with the text.
            start_idx = ent.get("start", -1)
            if (
                start_idx < 0
                or pdf_text[start_idx:start_idx + len(ent["text"])] != ent["text"]
            ):
                start_idx = pdf_text.find(ent["text"])
            if start_idx == -1:
                # Entity text not present in the joined document text
                # (e.g. spans a page boundary) — skip it, as before.
                continue

            # Get surrounding context (50 chars before and after)
            context_start = max(0, start_idx - 50)
            context_end = min(len(pdf_text), start_idx + len(ent["text"]) + 50)
            context = pdf_text[context_start:context_end]

            result.append(Entity(
                entity=ent["text"],
                context=context,
                start=start_idx - context_start,  # start relative to context
                end=start_idx - context_start + len(ent["text"]),
            ))

        logger.debug("Returning %d processed entities", len(result))
        return result
    except Exception:
        # logger.exception records the full traceback at ERROR level
        logger.exception("Error during extraction")
        raise