Spaces:
Runtime error
Runtime error
from typing import List | |
from pydantic import BaseModel | |
import pdfplumber | |
from fastapi import UploadFile | |
from gliner import GLiNER | |
import logging | |
# Configure the root logger to emit DEBUG-and-above records (a no-op if the
# root logger already has handlers), then obtain this module's named logger.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
class Entity(BaseModel):
    """A single extracted entity mention plus a snippet of surrounding text.

    As populated by ``extract_entities_from_pdf``, ``start`` and ``end`` are
    character offsets of the mention *within ``context``* (not within the full
    document), so that ``context[start:end] == entity``.
    """

    entity: str   # surface text of the mention
    context: str  # window of document text around the mention
    start: int    # index of the mention's first character inside `context`
    end: int      # index one past the mention's last character inside `context`
# Curated biomedical label set handed to GLiNER as the entity types to detect.
# Written as a whitespace-separated table for readability; .split() turns it
# into a list of label strings in exactly this reading order.
MEDICAL_LABELS = """
    gene protein protein_isoform cell disease
    phenotypic_feature clinical_finding anatomical_entity
    pathway biological_process drug small_molecule
    food_additive chemical_mixture molecular_entity
    clinical_intervention clinical_trial hospitalization
    geographic_location environmental_feature environmental_process
    publication journal_article book patent dataset
    study_result human mammal plant virus bacterium
    cell_line biological_sex clinical_attribute
    socioeconomic_attribute environmental_exposure drug_exposure
    procedure treatment device diagnostic_aid event
""".split()
# Hugging Face checkpoint id of the GLiNER model used for extraction.
_GLINER_CHECKPOINT = "knowledgator/gliner-multitask-large-v0.5"

# Load the model once at import time; extract_entities_from_pdf reuses this
# module-level instance on every call. (from_pretrained presumably fetches the
# weights from the Hub on first use — confirm caching in the deploy target.)
gliner_model = GLiNER.from_pretrained(_GLINER_CHECKPOINT)
# Maximum number of whitespace-delimited words sent to GLiNER in one call.
# GLiNER truncates inputs longer than its configured max length (384 word
# tokens for the common checkpoints, with punctuation counted as extra
# tokens), so the document is processed in windows safely below that limit.
# Entities straddling a window boundary may be missed.
_MAX_WORDS_PER_CHUNK = 250

# Characters of surrounding document text kept on each side of a mention.
_CONTEXT_CHARS = 50


def _read_pdf_text(fileobj) -> str:
    """Return the text of every page of the PDF in *fileobj*, joined by spaces."""
    with pdfplumber.open(fileobj) as pdf:
        logger.debug("Successfully opened PDF with %d pages", len(pdf.pages))
        # Pages without a text layer (e.g. scanned images) can yield None on
        # some pdfplumber versions; treat them as empty instead of crashing join().
        return " ".join(page.extract_text() or "" for page in pdf.pages)


def _chunk_text(text: str, max_words: int = _MAX_WORDS_PER_CHUNK) -> List[tuple]:
    """Split *text* into ``(offset, chunk)`` pairs of at most *max_words* words.

    ``offset`` is the index in *text* where ``chunk`` begins, so a position
    inside a chunk maps back to *text* by adding ``offset``. Empty or
    whitespace-only text yields no chunks.
    """
    import re

    spans = [match.span() for match in re.finditer(r"\S+", text)]
    chunks = []
    for i in range(0, len(spans), max_words):
        window = spans[i:i + max_words]
        begin, end = window[0][0], window[-1][1]
        chunks.append((begin, text[begin:end]))
    return chunks


def _locate_entity(text: str, ent: dict, offset: int) -> int:
    """Return the index of *ent*'s mention in *text*, or -1 if it is not found.

    Prefers the ``start`` character offset reported by GLiNER (relative to the
    chunk that began at *offset*) so repeated mentions keep their own
    positions; falls back to a substring search if that offset is missing or
    does not match the mention text.
    """
    surface = ent["text"]
    start = ent.get("start")
    if isinstance(start, int):
        start += offset
        if text[start:start + len(surface)] == surface:
            return start
    found = text.find(surface, offset)
    return found if found != -1 else text.find(surface)


def _make_entity(text: str, surface: str, start: int) -> Entity:
    """Build an Entity for the mention at ``text[start:start + len(surface)]``."""
    context_start = max(0, start - _CONTEXT_CHARS)
    context_end = min(len(text), start + len(surface) + _CONTEXT_CHARS)
    return Entity(
        entity=surface,
        context=text[context_start:context_end],
        # Positions are relative to the context snippet, not the document.
        start=start - context_start,
        end=start - context_start + len(surface),
    )


def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    The PDF text is extracted page by page and split into word windows small
    enough to avoid GLiNER's input truncation. Each window is run through the
    model with MEDICAL_LABELS at a 0.7 confidence threshold. Mentions of two
    characters or fewer are discarded.

    Args:
        file (UploadFile): The uploaded PDF file. Its underlying ``file``
            object is read directly; no temporary copy is made.
    Returns:
        List[Entity]: One Entity per kept mention, each with up to 50
        characters of surrounding text on either side as context.
    Raises:
        Exception: Any error from PDF parsing or model inference is logged
        and re-raised unchanged.
    """
    logger.debug("Starting extraction for file: %s", file.filename)
    try:
        pdf_text = _read_pdf_text(file.file)
        logger.debug("Extracted text length: %d characters", len(pdf_text))

        logger.debug("Starting GLiNER entity extraction")
        result: List[Entity] = []
        found = 0
        for offset, chunk in _chunk_text(pdf_text):
            entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)
            found += len(entities)
            for ent in entities:
                surface = ent["text"]
                if len(surface) <= 2:  # Skip very short entities
                    continue
                start = _locate_entity(pdf_text, ent, offset)
                if start != -1:
                    result.append(_make_entity(pdf_text, surface, start))
        logger.debug("Found %d entities", found)

        logger.debug("Returning %d processed entities", len(result))
        return result
    except Exception as e:
        logger.exception("Error during extraction: %s", e)
        raise