# LegalDoc / app.py
import streamlit as st
import shelve
import docx2txt
import PyPDF2
import time  # used for the typing effect and for timing responses
import nltk
import re
import os
from dotenv import load_dotenv
import torch
from sentence_transformers import SentenceTransformer, util
import hashlib
from nltk import sent_tokenize
from transformers import (
    LEDTokenizer,
    LEDForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
)
import asyncio
import dateutil.parser
from datetime import datetime
import sys
from openai import OpenAI
import numpy as np

nltk.download('punkt')
nltk.download('punkt_tab')
# Fix for "RuntimeError: no running event loop" on Windows
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
if "processed" not in st.session_state:
st.session_state.processed = False
if "last_uploaded_hash" not in st.session_state:
st.session_state.last_uploaded_hash = None
if "chat_prompt_processed" not in st.session_state:
st.session_state.chat_prompt_processed = False
if "embedding_text" not in st.session_state:
st.session_state.embedding_text = None
if "document_context" not in st.session_state:
st.session_state.document_context = None
if "last_prompt_hash" not in st.session_state:
st.session_state.last_prompt_hash = None
st.title("📄 Legal Document Summarizer (Alt Model w/o token doc Aug)")

USER_AVATAR = "👤"
BOT_AVATAR = "🤖"
# Load chat history
def load_chat_history():
    with shelve.open("chat_history") as db:
        return db.get("messages", [])

# Save chat history
def save_chat_history(messages):
    with shelve.open("chat_history") as db:
        db["messages"] = messages
# Function to limit text preview to 500 words
def limit_text(text, word_limit=500):
    words = text.split()
    return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
# CLEAN AND NORMALIZE TEXT
def clean_text(text):
    # Replace newlines with spaces and collapse runs of whitespace
    text = text.replace('\r\n', ' ').replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)

    # Remove page number markers like "Page 1 of 10"
    text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)

    # Remove long dashed or underscored separator lines
    text = re.sub(r'[_]{5,}', '', text)  # underscores: _____
    text = re.sub(r'[-]{5,}', '', text)  # hyphens: -----

    # Remove long dotted separators like "......"
    text = re.sub(r'[.]{4,}', '', text)

    # Trim leading/trailing whitespace
    return text.strip()
#######################################################################################################################
# LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
# Load token from .env file
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv("OPENAI_API_KEY"),
)
# print("API Key:", os.getenv("OPENAI_API_KEY")) # Temporary for debugging
# Load once at the top (cache for performance)
@st.cache_resource
def load_local_zero_shot_classifier():
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

local_classifier = load_local_zero_shot_classifier()
SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]
def classify_chunk(text):
    result = local_classifier(text, candidate_labels=SECTION_LABELS)
    return result["labels"][0]
# NEW: NLP-based sectioning using zero-shot classification
def section_by_zero_shot(text):
    # Keys must match SECTION_LABELS exactly, otherwise classified chunks
    # silently fall through to "Others" (the original used "Judgment" here
    # but "Judgement" in the labels)
    sections = {"Facts": "", "Arguments": "", "Judgement": "", "Others": ""}
    sentences = sent_tokenize(text)
    chunk = ""
    for i, sent in enumerate(sentences):
        chunk += sent + " "
        if (i + 1) % 3 == 0 or i == len(sentences) - 1:
            label = classify_chunk(chunk.strip())
            print(f"🔎 Chunk: {chunk[:60]}...\n🔖 Predicted Label: {label}")
            # Normalize label (title case and fallback)
            label = label.capitalize()
            if label not in sections:
                label = "Others"
            sections[label] += chunk + "\n"
            chunk = ""
    return sections
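# The result has this shape (illustrative values only, not real output):
#   {
#       "Facts": "chunk text assigned to Facts...\n",
#       "Arguments": "...",
#       "Judgement": "...",
#       "Others": "anything the classifier could not place\n",
#   }
# Each value concatenates the ~3-sentence chunks the zero-shot classifier
# assigned to that label.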
#######################################################################################################################
# EXTRACTING TEXT FROM UPLOADED FILES
# Function to extract text from uploaded file
def extract_text(file):
    if file.name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    elif file.name.endswith(".docx"):
        full_text = docx2txt.process(file)
    elif file.name.endswith(".txt"):
        full_text = file.read().decode("utf-8")
    else:
        return "Unsupported file type."
    return full_text  # full text is needed for summarization
#######################################################################################################################
# EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
@st.cache_resource
def load_legalbert():
    return SentenceTransformer("nlpaueb/legal-bert-base-uncased")

legalbert_model = load_legalbert()
@st.cache_resource
def load_led():
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    return tokenizer, model

tokenizer_led, model_led = load_led()
@st.cache_resource
def load_led_summarizer():
    # "allenai/led-base-16384" (or "allenai/led-large-16384" for higher quality)
    return pipeline(
        "summarization",
        model="allenai/led-base-16384",
        tokenizer="allenai/led-base-16384",
        device=0 if torch.cuda.is_available() else -1,
    )

led_summarizer = load_led_summarizer()
@st.cache_resource
def load_paraphraser():
    tok = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
    return pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tok,
        device=0 if torch.cuda.is_available() else -1,
        max_length=256,
        num_beams=4,
        do_sample=False,
    )

paraphraser = load_paraphraser()
def humanize(text):
    # Paraphrase the raw model output into a more conversational register;
    # called at the end of rag_query_response
    out = paraphraser(
        f"paraphrase: {text}",
        max_length=256,
        num_beams=4,
        do_sample=False,
    )[0]["generated_text"]
    return out
def legalbert_extractive_summary(text, top_ratio=0.2):
    sentences = sent_tokenize(text)
    top_k = max(3, int(len(sentences) * top_ratio))
    if len(sentences) <= top_k:
        return text
    sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    doc_embedding = torch.mean(sentence_embeddings, dim=0)
    cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)
    selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
    return " ".join(selected_sentences)
# LED abstractive summarization
def led_abstractive_summary(text, max_length=512, min_length=100):
    inputs = tokenizer_led(
        text, return_tensors="pt", padding="max_length",
        truncation=True, max_length=4096,
    )
    # LED expects a global attention mask; put global attention on the first token
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model_led.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=4,              # beam search
        repetition_penalty=2.0,   # penalize repetition
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=4,   # prevent repeated phrases
    )
    return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
def led_abstractive_summary_chunked(text, max_tokens=3000):
    sentences = sent_tokenize(text)
    current_chunk, chunks, summaries = "", [], []
    for sent in sentences:
        # Guard on current_chunk so a long first sentence cannot push an
        # empty chunk into the list
        if current_chunk and len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
            chunks.append(current_chunk)
            current_chunk = sent
        else:
            current_chunk += " " + sent
    if current_chunk:
        chunks.append(current_chunk)

    for chunk in chunks:
        inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
        global_attention_mask = torch.zeros_like(inputs["input_ids"])
        global_attention_mask[:, 0] = 1  # global attention on the first token
        output = model_led.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=512,
            min_length=100,
            num_beams=4,
            repetition_penalty=2.0,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=4,
        )
        summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))
    return " ".join(summaries)
def hybrid_summary_hierarchical(text, top_ratio=0.8):
    cleaned_text = clean_text(text)
    sections = section_by_zero_shot(cleaned_text)
    structured_summary = {}  # hierarchical summary lives here
    for name, content in sections.items():
        if content.strip():
            # Extractive pass, then an abstractive pass over the extract
            extractive = legalbert_extractive_summary(content, top_ratio)
            abstractive = led_abstractive_summary_chunked(extractive)
            structured_summary[name] = {
                "extractive": extractive,
                "abstractive": abstractive,
            }
    return structured_summary
@st.cache_resource
def load_embedder():
    return SentenceTransformer("all-MiniLM-L6-v2")

embedder = load_embedder()
def chunk_text_custom(text, n=1000, overlap=200):
    # Slide a character window of size n, stepping n - overlap each time
    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks
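# Window arithmetic sketch: with n=1000 and overlap=200 the windows start at
# 0, 800, 1600, ..., so consecutive chunks share 200 characters. A tiny
# worked example (assumed toy input):
#   chunk_text_custom("abcdefghij", n=4, overlap=2)
#   -> ["abcd", "cdef", "efgh", "ghij", "ij"]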
def get_embedding(text, model="BAAI/bge-en-icl"):
    """Create an embedding for the given text chunk using the BGE-ICL model."""
    resp = client.embeddings.create(model=model, input=text)
    return np.array(resp.data[0].embedding)
def create_embeddings(text_chunks, model="BAAI/bge-en-icl"):
    """Batch get_embedding over the chunks; returns a list of numpy arrays."""
    return [get_embedding(chunk, model=model) for chunk in text_chunks]
def generate_questions(text_chunk, num_questions=5):
    """
    Use the LED summarizer to generate a small set of probing questions
    about this chunk that the final answer should address.
    """
    prompt = (
        "You are a question-generation expert. "
        "From the text below, generate "
        f"{num_questions} concise questions:\n\n{text_chunk}"
    )
    out = led_summarizer(
        prompt,
        max_length=128,
        min_length=32,
        num_beams=4,
        do_sample=False,
    )[0]["summary_text"]
    # Assume each question is on its own line
    questions = [q.strip() for q in out.split("\n") if q.strip()]
    return questions[:num_questions]
def process_document(document_text):
    """
    1) chunk the document
    2) embed each chunk with the SentenceTransformer
    Returns (chunks, embeddings).
    """
    chunks = chunk_text_custom(document_text, n=800, overlap=200)
    embeddings = embedder.encode(chunks, convert_to_tensor=False)
    return chunks, embeddings
def semantic_search(query, chunks, chunk_embeddings, k=5):
    """
    Score each chunk by cosine similarity to the query embedding
    and return the top-k chunks in descending score order.
    """
    q_emb = embedder.encode([query], convert_to_tensor=False)[0]
    scores = [
        float(np.dot(q_emb, emb) / (np.linalg.norm(q_emb) * np.linalg.norm(emb)))
        for emb in chunk_embeddings
    ]
    top_idxs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [chunks[i] for i in top_idxs]
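# Usage sketch (hypothetical query; runs on the in-memory embeddings):
#   chunks, embs = process_document(document_text)
#   top = semantic_search("What did the court decide?", chunks, embs, k=5)
# The scores are plain cosine similarities computed with numpy, so
# chunk_embeddings can be any sequence of 1-D arrays from embedder.encode().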
def prepare_context(questions, chunks, chunk_embeddings, k_per_question=2):
    """
    For each generated question, pick its top-k supporting chunks,
    then dedupe and concatenate into one context string.
    """
    selected = []
    for q in questions:
        best = semantic_search(q, chunks, chunk_embeddings, k=k_per_question)
        selected.extend(best)
    # Dedupe while preserving order
    seen = set()
    context = []
    for c in selected:
        if c not in seen:
            seen.add(c)
            context.append(c)
    return "\n\n".join(f"• {c}" for c in context)
def rag_query_response(prompt, document_text):
    """
    Document-augmentation RAG:
      1. generate probing sub-questions about the doc
      2. process the doc (chunk + embed)
      3. build a minimal context via those questions
      4. feed context + user prompt into LED
      5. paraphrase (humanize)
    """
    # 1) Probing questions
    questions = generate_questions(document_text, num_questions=5)
    # 2) Chunk & embed the document
    chunks, chunk_embs = process_document(document_text)
    # 3) Assemble the distilled context
    context = prepare_context(questions, chunks, chunk_embs, k_per_question=2)
    # 4) Compose the LED input
    led_input = (
        "You are a knowledgeable legal assistant. "
        "Answer the user's question **using ONLY** the context below, "
        "and speak in a friendly, conversational tone.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {prompt}\n\nAnswer:"
    )
    raw = led_summarizer(
        led_input,
        max_length=512,
        min_length=64,
        do_sample=False,
    )[0]["summary_text"]
    # 5) Humanize
    return humanize(raw)
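# End-to-end usage sketch (hypothetical document text; exercises the whole
# pipeline above):
#   doc_text = extract_text(uploaded_file)  # or any plain string
#   answer = rag_query_response("Who won the appeal?", doc_text)
# Note that the context handed to LED is built from the generated probing
# questions, not from the user's prompt itself.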
#######################################################################################################################
# STREAMLIT APP INTERFACE CODE
# Initialize or load chat history
if "messages" not in st.session_state:
st.session_state.messages = load_chat_history()
# Initialize last_uploaded if not set
if "last_uploaded" not in st.session_state:
st.session_state.last_uploaded = None
# Sidebar with a button to delete chat history
with st.sidebar:
    st.subheader("⚙️ Options")
    if st.button("Delete Chat History"):
        st.session_state.messages = []
        st.session_state.last_uploaded = None
        st.session_state.processed = False
        st.session_state.chat_prompt_processed = False
        save_chat_history([])
# Display chat messages with a typing effect
def display_with_typing_effect(text, speed=0.005):
    placeholder = st.empty()
    displayed_text = ""
    for char in text:
        displayed_text += char
        placeholder.markdown(displayed_text)
        time.sleep(speed)
    return displayed_text
# Show existing chat messages
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
# Standard chat input field
prompt = st.chat_input("Type a message...")
# Place uploader before the chat so it's always visible
with st.container():
    st.subheader("📎 Upload a Legal Document")
    uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
    reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")
# Hashing logic
def get_file_hash(file):
    file.seek(0)
    content = file.read()
    file.seek(0)
    return hashlib.md5(content).hexdigest()
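# MD5 is used purely as a cheap change detector for re-uploads, not for
# security: identical bytes always hash to the same digest, e.g. (toy check)
#   hashlib.md5(b"same bytes").hexdigest() == hashlib.md5(b"same bytes").hexdigest()  # True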
# Combine the extractive and abstractive summaries into a single string for embedding
def prepare_text_for_embedding(summary_dict):
    combined_chunks = []
    for section, content in summary_dict.items():
        ext = content.get("extractive", "").strip()
        abstr = content.get("abstractive", "").strip()  # renamed so we don't shadow the built-in abs()
        if ext:
            combined_chunks.append(f"{section} - Extractive Summary:\n{ext}")
        if abstr:
            combined_chunks.append(f"{section} - Abstractive Summary:\n{abstr}")
    return "\n\n".join(combined_chunks)
###################################################################################################################
user_role = st.sidebar.selectbox(
    "🎭 Select Your Role for Custom Summary",
    ["General", "Judge", "Lawyer", "Student"],
)
def role_based_filter(section, summary, role):
    if role == "General":
        return summary

    filtered_summary = {
        "extractive": "",
        "abstractive": "",
    }
    if role == "Judge" and section in ["Judgement", "Facts"]:
        filtered_summary = summary
    elif role == "Lawyer" and section in ["Arguments", "Facts"]:
        filtered_summary = summary
    elif role == "Student" and section in ["Facts"]:
        filtered_summary = summary
    return filtered_summary
#########################################################################################################################
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
        st.session_state.processed = False

    if not st.session_state.processed:
        start_time = time.time()
        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)
        embedding_text = prepare_text_for_embedding(summary_dict)

        # Generate and display the RAG-based summary
        st.session_state.document_context = embedding_text
        role_specific_prompt = (
            f"As a {user_role}, summarize the legal document focusing on the most relevant aspects "
            "such as facts, arguments, and judgments tailored for your role. "
            "Include key legal reasoning and a timeline of events where necessary."
        )
        rag_summary = rag_query_response(role_specific_prompt, embedding_text)

        st.session_state.messages.append({"role": "user", "content": f"📤 Uploaded **{uploaded_file.name}**"})
        st.session_state.messages.append({"role": "assistant", "content": rag_summary})
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(rag_summary)

        processing_time = round((time.time() - start_time) / 60, 2)
        st.info(f"⏱️ Response generated in **{processing_time} minutes**.")

        st.session_state.generated_summary = rag_summary  # kept for the evaluation section below
        st.session_state.last_uploaded_hash = file_hash
        st.session_state.processed = True
        st.session_state.last_prompt_hash = None
        save_chat_history(st.session_state.messages)
if prompt:
    words = prompt.split()
    word_count = len(words)
    prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()

    # 1) LONG prompts: treat the pasted text as a document, echo it, then summarize
    if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
        # Mark this prompt as seen
        st.session_state.last_prompt_hash = prompt_hash

        # raw_text is just the prompt text
        raw_text = prompt
        st.session_state.messages.append({
            "role": "user",
            "content": f"📥 **Pasted Document Text:**\n\n{limit_text(raw_text, word_limit=500)}"
        })
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(limit_text(raw_text, word_limit=500))

        start_time = time.time()
        summary_dict = hybrid_summary_hierarchical(raw_text)
        emb_text = prepare_text_for_embedding(summary_dict)
        st.session_state.document_context = emb_text
        st.session_state.processed = True

        role_prompt = (
            f"As a {user_role}, summarize the document focusing on facts, "
            "arguments, judgments, plus a timeline of events."
        )
        initial_summary = rag_query_response(role_prompt, emb_text)

        # Append & display the assistant's summary with a typing effect
        st.session_state.messages.append({
            "role": "assistant",
            "content": initial_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(initial_summary)

        st.info(f"⏱️ Summary generated in {round((time.time() - start_time) / 60, 2)} minutes")
        save_chat_history(st.session_state.messages)

    # 2) SHORT prompts: normal RAG against the last ingested context
    elif word_count <= 30 and st.session_state.processed:
        # A single role-prefixed query (the original made a second, role-less
        # call that overwrote this answer)
        role_query = f"As a {user_role}, {prompt}"
        answer = rag_query_response(role_query, st.session_state.document_context)

        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(answer)
        save_chat_history(st.session_state.messages)

    # 3) Nothing ingested yet
    else:
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            st.markdown("❗ Paste at least 30 words of your document to ingest it first.")
################################ Evaluation ###########################

import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Load the evaluators once and cache them
@st.cache_resource
def load_evaluators():
    rouge = evaluate.load("rouge")
    bertscore = evaluate.load("bertscore")
    return rouge, bertscore

rouge, bertscore = load_evaluators()

# Evaluation metric helpers: ROUGE, BERTScore, BLEU
def evaluate_summary(generated_summary, ground_truth_summary):
    """Evaluate ROUGE and BERTScore against a reference summary."""
    rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
    bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
    return rouge_result, bert_result

def compute_bleu(prediction, ground_truth):
    """Compute a smoothed sentence-level BLEU score."""
    reference = [ground_truth.strip().split()]
    candidate = prediction.strip().split()
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, candidate, smoothing_function=smoothie)
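# Illustrative check (hypothetical strings, not project data): identical
# candidate and reference score 1.0 even with smoothing applied.
#   compute_bleu("the court dismissed the appeal",
#                "the court dismissed the appeal")  # -> 1.0
# sentence_bleu expects a list of tokenized references, hence the nested list.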
# 📥 Upload a ground-truth summary and evaluate the generated one against it
ground_truth_summary_file = st.file_uploader("📄 Upload Ground Truth Summary (.txt)", type=["txt"])

if ground_truth_summary_file:
    ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()

    if "generated_summary" in st.session_state and st.session_state.generated_summary:
        prediction = st.session_state.generated_summary

        # ROUGE and BERTScore
        rouge_result, bert_result = evaluate_summary(prediction, ground_truth_summary)
        st.subheader("📊 Evaluation Results")
        st.write("🔹 ROUGE Scores:")
        st.json(rouge_result)
        st.write("🔹 BERTScore:")
        st.json(bert_result)

        # BLEU
        bleu = compute_bleu(prediction, ground_truth_summary)
        st.subheader("🔵 BLEU Score")
        st.write(f"BLEU Score: {bleu:.4f}")
    else:
        st.warning("⚠️ Please generate a summary first by uploading a document.")
######################################################################################################################
# Run this along with `streamlit run app.py` to evaluate the model's performance on a test set.
# Otherwise, leave the code below commented out.

# ⇒ EVALUATION HOOK: after the very first summary, kick off a background evaluation run once

# import json
# import pandas as pd
# import threading

# def run_eval(doc_context):
#     # 1) load the ground-truth test cases
#     with open("test_case1.json", "r", encoding="utf-8") as f:
#         gt_data = json.load(f)
#
#     # 2) answer each query against the captured document context
#     records = []
#     for entry in gt_data:
#         doc_id = entry["document_id"]
#         query = entry["query"]
#         gt_ans = entry["ground_truth_answer"]
#         model_ans = rag_query_response(query, doc_context)
#         records.append({
#             "document_id": doc_id,
#             "query": query,
#             "ground_truth_answer": gt_ans,
#             "model_answer": model_ans
#         })
#         print(f"✅ Done {doc_id} / \"{query}\"")
#
#     # 3) push to DataFrame + CSV
#     df = pd.DataFrame(records)
#     out = "evaluation_results.csv"
#     df.to_csv(out, index=False, encoding="utf-8")
#     print(f"\n📝 Saved {len(df)} rows to {out}")

# def _run_evaluation(doc_context):
#     try:
#         run_eval(doc_context)  # run_eval needs the document context passed through
#     except Exception as e:
#         print("‼️ Evaluation script error:", e)

# if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
#     st.session_state.evaluation_launched = True
#
#     # inform the user
#     st.sidebar.info("🔬 Starting background evaluation run...")
#
#     # capture the context before spawning the thread
#     doc_ctx = st.session_state.document_context
#     threading.Thread(
#         target=lambda: run_eval(doc_ctx),
#         daemon=True
#     ).start()
#     st.sidebar.success("✅ Evaluation launched; check evaluation_results.csv when done.")
#
#     # check for file existence & show a download button
#     eval_path = os.path.abspath("evaluation_results.csv")
#     if os.path.exists(eval_path):
#         st.sidebar.success(f"✅ Results saved to:\n`{eval_path}`")
#         df_eval = pd.read_csv(eval_path)
#         st.sidebar.download_button(
#             label="⬇️ Download evaluation_results.csv",
#             data=df_eval.to_csv(index=False).encode("utf-8"),
#             file_name="evaluation_results.csv",
#             mime="text/csv"
#         )
#     else:
#         # display the cwd so the output location can be inspected
#         st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")