import os
import random
import time
import xml.etree.ElementTree as ET
from urllib.request import urlretrieve

import numpy as np
import pandas as pd
import requests
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

token = os.getenv("githubToken")

def noop_logger(*args, **kwargs):
    """Logger stub that silently discards all messages."""
    pass

def download_pdf(paper, max_retries=3):
    """Download a paper's PDF, retrying with exponential backoff on rate limiting."""
    if pd.isna(paper.pdf_url):
        paper.log("ERROR", "Missing PDF URL")
        return paper

    pdf_path = paper.pdf_path
    if os.path.exists(pdf_path):
        return paper

    headers = {'User-Agent': 'Mozilla/5.0'}
    for attempt in range(max_retries):
        try:
            response = requests.get(paper.pdf_url, headers=headers)
            if response.status_code == 200:
                with open(pdf_path, "wb") as f:
                    f.write(response.content)
                # Small random delay to avoid hammering the server
                time.sleep(random.uniform(1.0, 3.0))
                return paper
            elif response.status_code == 429:
                wait = 2 ** attempt
                paper.log("WARNING", f"Rate limited, retrying in {wait}s...")
                time.sleep(wait)
            else:
                paper.log("ERROR", f"Download failed: HTTP {response.status_code}")
                break
        except Exception as e:
            paper.log("ERROR", f"Download error: {e}")
            time.sleep(1)
    return paper

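# Usage sketch (hypothetical stand-in for the paper records used elsewhere in this app;
# the real objects provide pdf_url, pdf_path and a log(level, message) method):
#
#   from types import SimpleNamespace
#   paper = SimpleNamespace(
#       pdf_url="https://arxiv.org/pdf/1706.03762",
#       pdf_path="1706.03762.pdf",
#       log=lambda level, msg: print(level, msg),
#   )
#   download_pdf(paper)  # retries with exponential backoff on HTTP 429
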
def get_api_link(url):
    """Build the GitHub API zipball URL for a repository URL, or return "" if parsing fails."""
    username, repo_name = decompose_url(url)
    if username is None:
        return ""
    return f"https://api.github.com/repos/{username}/{repo_name}/zipball/"


def decompose_url(url):
    """Split a GitHub URL into (username, repo_name); returns (None, None) on failure."""
    try:
        url = url.split("github.com")[1]
        url = url.strip(".")
        url = url.split(".git")[0]
        url = url.strip("/")
        parts = url.split("/")
        username = parts[0]
        repo_name = parts[1]
        return username, repo_name
    except Exception:
        return None, None

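# Illustration of the URL helpers (example repository, not taken from this project):
#   decompose_url("https://github.com/huggingface/transformers.git")
#     -> ("huggingface", "transformers")
#   get_api_link("https://github.com/huggingface/transformers")
#     -> "https://api.github.com/repos/huggingface/transformers/zipball/"
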
def fetch_repo(repo_url, repo_name, token, force_download=False):
    """Download a GitHub repository as a zipball to repo_name via the GitHub API."""
    if os.path.exists(repo_name) and not force_download:
        return
    if "github.com" not in repo_url:
        raise ValueError(f"URL not for github repo, please evaluate manually ({repo_url}).")

    headers = {"Authorization": f"token {token}"}
    api_url = get_api_link(repo_url)
    if api_url == "":
        raise ValueError(f"Failed to parse the URL, please evaluate manually ({repo_url}).")

    # Sending GET request to GitHub API
    response = requests.get(api_url, headers=headers)
    if response.status_code == 200:
        with open(repo_name, 'wb') as file:
            file.write(response.content)
    if response.status_code == 404:
        raise ValueError("Repository private / Link broken.")

def download_repo(paper):
    """Download the paper's main code repository as a zip archive."""
    try:
        if paper.main_repo_url is None:
            return paper
        fetch_repo(paper.main_repo_url, paper.zip_path, token)
    except Exception as e:
        paper.log("ERROR", f"Repo download failed: {e}")
    return paper

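# Usage sketch (illustrative URL and target path; token is the GitHub token read from
# the environment above). fetch_repo stores the repository as a zipball, so the target
# name should be a .zip path:
#
#   fetch_repo("https://github.com/huggingface/transformers", "transformers.zip", token)
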
def pdf_to_grobid(filename, save_path=None, grobid_url="https://attilasimko-grobid.hf.space/", force_download=False):
    """
    Convert a PDF to Grobid TEI XML.

    Parameters:
        filename (str or list): Path to a PDF file, a list of PDF files, or a directory of PDFs.
        save_path (str, optional): Directory or file path to save to. Defaults to a temp file in the current directory.
        grobid_url (str, optional): URL of the Grobid server. Defaults to the public Space above.
        force_download (bool, optional): Re-process even if the output XML already exists.

    Returns:
        str or list: Path(s) to the saved XML file(s), or the parsed XML root if saved to the temp file.
    """
    # Determine save path
    if save_path is None:
        save_file = os.path.join(os.getcwd(), "temp_grobid.xml")
    elif os.path.isdir(save_path):
        base_name = os.path.splitext(os.path.basename(filename))[0] + ".xml"
        save_file = os.path.join(save_path, base_name)
    else:
        save_file = save_path if save_path.endswith(".xml") else save_path + ".xml"

    # Reuse an existing result unless a re-run is forced
    if os.path.exists(save_file) and not force_download:
        return ET.parse(save_file).getroot() if save_path is None else save_file

    def is_server_up(url):
        try:
            response = requests.get(url + "/api/health", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False

    if not is_server_up(grobid_url):
        raise ConnectionError(f"The Grobid server {grobid_url} is not available.")

    # Handle multiple files
    if isinstance(filename, list):
        if save_path is None or not os.path.isdir(save_path):
            print(f"Warning: {save_path} is not a directory. XMLs will be saved in the current directory: {os.getcwd()}")
            save_path = "."
        xmls = []
        for pdf in tqdm(filename, desc="Processing PDFs"):
            try:
                xml = pdf_to_grobid(pdf, save_path, grobid_url)
                xmls.append(xml)
            except Exception as e:
                print(f"Error processing {pdf}: {e}")
                xmls.append(None)
        return xmls

    # Handle directory input
    if os.path.isdir(filename):
        pdfs = [os.path.join(filename, f) for f in os.listdir(filename) if f.endswith(".pdf")]
        if not pdfs:
            print(f"Warning: No PDF files found in directory {filename}")
        return pdf_to_grobid(pdfs, save_path, grobid_url)

    # Ensure file exists
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"The file {filename} does not exist.")

    # Send PDF to Grobid
    with open(filename, "rb") as file:
        files = {"input": file}
        post_url = f"{grobid_url}/api/processFulltextDocument"
        response = requests.post(post_url, files=files)

    if response.status_code != 200:
        # os.remove(filename)
        raise Exception(f"Error: {response.reason}")

    # Save the response
    with open(save_file, "wb") as f:
        f.write(response.content)

    # Return XML object if saved to temp file
    if save_path is None:
        return ET.parse(save_file).getroot()
    else:
        return save_file

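# Usage sketch (illustrative paths; grobid_url defaults to the Space given above):
#
#   # Single PDF, XML written into an existing directory
#   xml_path = pdf_to_grobid("papers/1706.03762.pdf", save_path="xml/")
#
#   # Without save_path the TEI is written to temp_grobid.xml and the parsed root is returned
#   root = pdf_to_grobid("papers/1706.03762.pdf")
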
def extract_body(xml_root):
    """Extracts and returns the text content of the paper's body from Grobid TEI XML."""
    namespace = {"tei": "http://www.tei-c.org/ns/1.0"}  # Define TEI namespace
    body_text = []

    # Locate <body> in the XML structure
    body = xml_root.find(".//tei:body", namespace)
    if body is not None:
        for p in body.findall(".//tei:p", namespace):  # Get all paragraphs inside <body>
            # itertext() also collects text nested inside inline tags such as <ref>
            text = "".join(p.itertext()).strip()
            if text:
                body_text.append(text)
    return "\n".join(body_text)