import streamlit as st
import pandas as pd
import asyncio
import json
import io
from typing import List, Dict, Any
import plotly.graph_objects as go
import plotly.express as px
from neo4j import GraphDatabase
import networkx as nx
# Configure page
st.set_page_config(
    page_title="Neo4j Knowledge Graph Builder",
    # Spider-web emoji; the previous literal was the mis-decoded byte sequence
    # of this character and is not a valid emoji.
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Hook for custom CSS. The injected markup is currently empty, so no style
# rules are applied.
st.markdown("""
""", unsafe_allow_html=True)

# Initialize session state: each key is only seeded when it is not already
# present, so values stored by earlier runs are left untouched.
_SESSION_DEFAULTS = {
    'driver': lambda: None,
    'kg_builder': lambda: None,
    # Factory so every session gets its own mutable containers.
    'graph_data': lambda: {'nodes': [], 'edges': []},
}
for _key, _make_default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
def init_neo4j_connection(uri: str, username: str, password: str):
    """Open a Neo4j driver, verify it, and store it in the session state.

    Args:
        uri: Connection URI passed to ``GraphDatabase.driver``.
        username: User name for basic authentication.
        password: Password for basic authentication.

    Returns:
        A ``(success, message)`` tuple. On success the verified driver is
        stored in ``st.session_state.driver``; on failure the session state
        is left unchanged.
    """
    driver = None
    try:
        driver = GraphDatabase.driver(uri, auth=(username, password))
        driver.verify_connectivity()
    except Exception as e:
        # A driver that was created but failed verification must still be
        # closed, otherwise it is leaked on every failed attempt.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                pass  # best effort: the connection error below is what matters
        return False, f"Connection failed: {str(e)}"

    previous = st.session_state.driver
    st.session_state.driver = driver

    if previous is not None and previous is not driver:
        # Release the driver being replaced instead of leaking it.
        try:
            previous.close()
        except Exception:
            pass  # best effort: the new connection is already in place
        # Any existing pipeline was constructed with the old driver; drop it
        # so writes cannot go to a different connection than reads.
        st.session_state.kg_builder = None

    return True, "Connected successfully!"
def init_kg_pipeline(openai_api_key: str, node_types: List[str],
                     relationship_types: List[str], patterns: List[tuple]):
    """Build a ``SimpleKGPipeline`` and store it in the session state.

    Args:
        openai_api_key: API key handed to both the embedder and the LLM.
        node_types: Node type names for the schema.
        relationship_types: Relationship type names for the schema.
        patterns: Schema patterns, passed through unchanged.

    Returns:
        A ``(success, message)`` tuple. On success the pipeline is stored in
        ``st.session_state.kg_builder``.
    """
    try:
        # Imported lazily so the rest of the app still loads when the
        # graphrag packages are not installed.
        from neo4j_graphrag.embeddings import OpenAIEmbeddings
        from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
        from neo4j_graphrag.llm import OpenAILLM

        embedding_model = OpenAIEmbeddings(
            model="text-embedding-3-large",
            api_key=openai_api_key,
        )

        llm_settings = {
            "max_tokens": 2000,
            "response_format": {"type": "json_object"},
            "temperature": 0,
        }
        language_model = OpenAILLM(
            model_name="gpt-4o",
            model_params=llm_settings,
            api_key=openai_api_key,
        )

        graph_schema = {
            "node_types": node_types,
            "relationship_types": relationship_types,
            "patterns": patterns,
        }
        pipeline = SimpleKGPipeline(
            llm=language_model,
            driver=st.session_state.driver,
            embedder=embedding_model,
            schema=graph_schema,
            on_error="IGNORE",
            from_pdf=False,
        )
        st.session_state.kg_builder = pipeline
    except Exception as exc:
        return False, f"Pipeline initialization failed: {exc}"
    return True, "Knowledge Graph Pipeline initialized successfully!"
def process_csv_data(df: pd.DataFrame, text_columns: List[str]) -> str:
    """Flatten selected DataFrame columns into paragraph-style text.

    Each row becomes one paragraph of ``"column: value"`` fragments joined by
    ``". "``. Requested columns that are absent from *df* and cells that are
    null are skipped, and rows that yield no fragments are omitted.
    Paragraphs are separated by a blank line.
    """
    present_columns = [name for name in text_columns if name in df.columns]

    paragraphs = []
    for _, record in df.iterrows():
        fragments = [
            f"{name}: {record[name]}"
            for name in present_columns
            if pd.notna(record[name])
        ]
        if fragments:
            paragraphs.append(". ".join(fragments))
    return "\n\n".join(paragraphs)
def read_uploaded_file(uploaded_file) -> str:
    """Extract text from an uploaded file.

    Plain text is decoded as UTF-8, CSV rows are rendered as
    ``"column: value"`` fragments (one paragraph per row, null cells
    skipped), and JSON is re-serialized with two-space indentation.

    Args:
        uploaded_file: File-like upload exposing ``type`` (MIME type) and
            the usual binary stream methods; ``name`` is consulted when
            the MIME type is not one of the recognised ones.

    Returns:
        The extracted text, or ``""`` when the file type is unsupported or
        reading fails (the problem is reported in the UI in both cases).
    """
    try:
        # The same upload object may already have been read (for example by
        # a CSV preview); rewind so it is always parsed from the start.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)

        mime_type = uploaded_file.type
        kind = {
            "text/plain": "txt",
            "text/csv": "csv",
            "application/json": "json",
        }.get(mime_type)

        if kind is None:
            # Unrecognised MIME type: fall back to the file extension.
            file_name = getattr(uploaded_file, "name", "") or ""
            suffix = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
            if suffix in ("txt", "csv", "json"):
                kind = suffix

        if kind == "txt":
            return str(uploaded_file.read(), "utf-8")

        if kind == "csv":
            df = pd.read_csv(uploaded_file)
            # Convert all columns to text representation
            text_data = []
            for _, row in df.iterrows():
                row_items = [
                    f"{col}: {row[col]}" for col in df.columns if pd.notna(row[col])
                ]
                if row_items:
                    text_data.append(". ".join(row_items))
            return "\n\n".join(text_data)

        if kind == "json":
            data = json.load(uploaded_file)
            return json.dumps(data, indent=2)

        # Report the problem instead of returning the message as content:
        # the return value is treated as source text by the caller.
        st.warning(f"File type {mime_type} is not supported for text extraction.")
        return ""
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
        return ""
async def build_knowledge_graph(text: str):
    """Run the session's KG pipeline over *text*.

    Returns:
        A ``(success, message)`` tuple; any exception raised while looking up
        or running the pipeline is converted into a failure message.
    """
    try:
        pipeline = st.session_state.kg_builder
        await pipeline.run_async(text=text)
    except Exception as exc:
        return False, f"Failed to build knowledge graph: {exc}"
    return True, "Knowledge graph built successfully!"
def fetch_graph_data():
    """Load up to 1000 nodes and 1000 relationships from Neo4j.

    The results are stored in ``st.session_state.graph_data`` as
    ``{'nodes': [...], 'edges': [...]}``.

    Returns:
        A ``(success, message)`` tuple; on failure the session state is left
        unchanged.
    """
    node_cypher = (
        "\n"
        "MATCH (n)\n"
        "RETURN id(n) as id, labels(n) as labels, properties(n) as properties\n"
        "LIMIT 1000\n"
    )
    edge_cypher = (
        "\n"
        "MATCH (a)-[r]->(b)\n"
        "RETURN id(a) as source, id(b) as target, type(r) as relationship, properties(r) as properties\n"
        "LIMIT 1000\n"
    )
    node_fields = ('id', 'labels', 'properties')
    edge_fields = ('source', 'target', 'relationship', 'properties')

    try:
        with st.session_state.driver.session() as db_session:
            # Each result is fully consumed before the next query is issued.
            nodes = [
                {field: row[field] for field in node_fields}
                for row in db_session.run(node_cypher)
            ]
            edges = [
                {field: row[field] for field in edge_fields}
                for row in db_session.run(edge_cypher)
            ]
        st.session_state.graph_data = {'nodes': nodes, 'edges': edges}
    except Exception as exc:
        return False, f"Failed to fetch graph data: {exc}"
    return True, f"Fetched {len(nodes)} nodes and {len(edges)} relationships"
def visualize_graph():
    """Render the graph in ``st.session_state.graph_data`` with Plotly.

    Builds an undirected NetworkX graph from the cached nodes and edges,
    lays it out with a spring layout, draws it as a Plotly scatter figure,
    and shows node/edge/type/density metrics underneath. Shows a warning and
    returns early when there is nothing to draw.
    """
    graph_data = st.session_state.graph_data
    if not graph_data['nodes']:
        st.warning("No graph data available. Please build a knowledge graph first.")
        return

    # Create NetworkX graph
    G = nx.Graph()

    # Add nodes. Attributes are passed as a single dict rather than as
    # keyword arguments next to **properties: a property called "name" or
    # "label" would otherwise raise "got multiple values for keyword argument".
    for node in graph_data['nodes']:
        label = node['labels'][0] if node['labels'] else 'Unknown'
        name = node['properties'].get('name', f"Node_{node['id']}")
        attrs = {**node['properties'], 'label': label, 'name': name}
        G.add_nodes_from([(node['id'], attrs)])

    # Add edges (same keyword-collision concern for a "relationship" property).
    # Endpoints missing from the node list are created without attributes and
    # fall back to the defaults used below.
    for edge in graph_data['edges']:
        attrs = {**edge['properties'], 'relationship': edge['relationship']}
        G.add_edges_from([(edge['source'], edge['target'], attrs)])

    if G.number_of_nodes() == 0:
        st.warning("No nodes to display")
        return

    # Generate layout; fall back to stacking every node at the origin if the
    # layout computation fails so the figure can still be drawn.
    try:
        pos = nx.spring_layout(G, k=1, iterations=50)
    except Exception:
        pos = {node: (0, 0) for node in G.nodes()}

    # Create edge trace: one polyline, with None separating the segments.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node trace
    node_x = []
    node_y = []
    node_info = []
    node_colors = []
    node_labels = []
    color_map = {
        'Person': '#FF6B6B',
        'House': '#4ECDC4',
        'Planet': '#45B7D1',
        'Organization': '#96CEB4',
        'Location': '#FFEAA7'
    }
    default_color = '#DDA0DD'

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        node_data = G.nodes[node]
        label = node_data.get('label', 'Unknown')
        name = node_data.get('name', f'Node_{node}')

        # Create hover info; falsy property values are left out.
        info_parts = [f"ID: {node}", f"Type: {label}", f"Name: {name}"]
        for key, value in node_data.items():
            if key not in ('label', 'name') and value:
                info_parts.append(f"{key}: {value}")
        # Plotly hover text is HTML-like, so lines are separated with <br>.
        node_info.append('<br>'.join(info_parts))

        node_colors.append(color_map.get(label, default_color))
        node_labels.append(name)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_labels,
        textposition="middle center",
        hovertext=node_info,
        marker=dict(
            size=30,
            color=node_colors,
            line=dict(width=2, color='white')
        )
    )

    # Create figure. The title font is set through title.font; the legacy
    # "titlefont" layout attribute is rejected by current Plotly releases.
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(text='Knowledge Graph Visualization', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for details",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                font=dict(color="#888888", size=12)
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            height=600
        )
    )
    st.plotly_chart(fig, use_container_width=True)

    # Display graph statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Nodes", len(G.nodes()))
    with col2:
        st.metric("Edges", len(G.edges()))
    with col3:
        distinct_labels = {G.nodes[node].get('label', 'Unknown') for node in G.nodes()}
        st.metric("Node Types", len(distinct_labels))
    with col4:
        density = nx.density(G) if len(G.nodes()) > 1 else 0
        st.metric("Density", f"{density:.3f}")
def query_graph(query: str):
    """Run a Cypher statement against the session's Neo4j driver.

    Returns:
        ``(True, records)`` with the result materialized as a list, or
        ``(False, error_text)`` when anything raises.
    """
    try:
        with st.session_state.driver.session() as db_session:
            # Materialize while the session is still open.
            rows = list(db_session.run(query))
    except Exception as exc:
        return False, str(exc)
    return True, rows
# NOTE: the emoji in the UI strings below were restored from mis-encoded
# text. The spider web and gear were recoverable; the remaining ones were
# reduced to a single placeholder glyph and have been chosen to match the label.

def _show_status(success, message):
    """Show *message* as a success banner when *success* is true, else as an error."""
    if success:
        st.success(message)
    else:
        st.error(message)


def _parse_lines(raw):
    """Return the non-blank lines of *raw*, stripped of surrounding whitespace."""
    return [line.strip() for line in raw.split('\n') if line.strip()]


def _parse_patterns(raw):
    """Parse ``Source,Relationship,Target`` lines into 3-tuples, reporting malformed lines."""
    patterns = []
    for line in _parse_lines(raw):
        parts = [p.strip() for p in line.split(',')]
        if len(parts) == 3 and all(parts):
            patterns.append(tuple(parts))
        else:
            st.warning(f"Ignoring malformed pattern: {line}")
    return patterns


def _render_sidebar():
    """Render the Neo4j, OpenAI, and schema configuration controls."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        # Neo4j Connection
        st.subheader("Neo4j Database")
        neo4j_uri = st.text_input("Neo4j URI", value="neo4j://localhost:7687")
        neo4j_username = st.text_input("Username", value="neo4j")
        neo4j_password = st.text_input("Password", type="password", value="password")
        if st.button("Connect to Neo4j"):
            _show_status(*init_neo4j_connection(neo4j_uri, neo4j_username, neo4j_password))

        # OpenAI Configuration
        st.subheader("OpenAI Settings")
        openai_api_key = st.text_input("OpenAI API Key", type="password")

        # Schema Configuration
        st.subheader("Knowledge Graph Schema")
        node_types = _parse_lines(st.text_area(
            "Node Types (one per line)",
            value="Person\nHouse\nPlanet\nOrganization\nLocation"))
        relationship_types = _parse_lines(st.text_area(
            "Relationship Types (one per line)",
            value="PARENT_OF\nHEIR_OF\nRULES\nWORKS_FOR\nLIVES_IN"))

        # Patterns
        st.subheader("Relationship Patterns")
        patterns = _parse_patterns(st.text_area(
            "Patterns (format: Source,Relationship,Target - one per line)",
            value="Person,PARENT_OF,Person\nPerson,HEIR_OF,House\nHouse,RULES,Planet\nPerson,WORKS_FOR,Organization\nPerson,LIVES_IN,Location"))

        if st.button("Initialize KG Pipeline"):
            # Report every unmet precondition; a click must never be a silent no-op.
            if not openai_api_key:
                st.error("Please enter your OpenAI API key first")
            elif not st.session_state.driver:
                st.error("Please connect to Neo4j first")
            else:
                _show_status(*init_kg_pipeline(
                    openai_api_key, node_types, relationship_types, patterns))


def _preview_csv(csv_file):
    """Show the shape and first rows of an uploaded CSV without consuming the upload."""
    try:
        csv_file.seek(0)
        df = pd.read_csv(csv_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
        st.error(f"Could not preview {csv_file.name}: {e}")
        return
    finally:
        # Rewind so a later read of the same upload (the build step) starts
        # at the beginning instead of at end-of-file.
        csv_file.seek(0)
    st.write(f"**{csv_file.name}** - Shape: {df.shape}")
    st.write("Preview:")
    st.dataframe(df.head())


def _render_data_input_tab():
    """Render the upload/text inputs and run the knowledge-graph build on request."""
    st.header("Data Input & Knowledge Graph Creation")

    # File upload. PDF is not offered: no PDF text extraction is implemented
    # and the pipeline is created with from_pdf=False.
    uploaded_files = st.file_uploader(
        "Upload files for knowledge graph creation",
        type=['txt', 'csv', 'json'],
        accept_multiple_files=True
    )

    # Text input
    text_input = st.text_area("Or paste text directly:", height=200)

    # CSV specific options
    csv_files = [f for f in (uploaded_files or []) if f.type == "text/csv"]
    if csv_files:
        st.subheader("CSV Processing Options")
        for csv_file in csv_files:
            _preview_csv(csv_file)

    # Process button
    if not st.button("🚀 Build Knowledge Graph", type="primary"):
        return
    if not st.session_state.kg_builder:
        st.error("Please initialize the KG Pipeline first")
        return

    sections = []
    for uploaded_file in (uploaded_files or []):
        file_content = read_uploaded_file(uploaded_file)
        # Skip files that produced no text so a bare header is never sent
        # to the pipeline as if it were content.
        if file_content.strip():
            sections.append(f"\n\n--- From {uploaded_file.name} ---\n{file_content}")
    if text_input.strip():
        sections.append(f"\n\n--- Direct Input ---\n{text_input.strip()}")
    all_text = "".join(sections)

    if not all_text.strip():
        st.warning("Please provide some text or upload files to process")
        return

    with st.spinner("Building knowledge graph..."):
        try:
            # asyncio.run creates a fresh event loop and always closes it,
            # including when the coroutine raises.
            success, message = asyncio.run(build_knowledge_graph(all_text))
        except Exception as e:
            st.error(f"Error: {str(e)}")
            return

        if not success:
            st.error(message)
            return
        st.success(message)

        # Fetch updated graph data
        fetch_success, fetch_message = fetch_graph_data()
        if fetch_success:
            st.info(fetch_message)
        else:
            st.warning(f"Graph built but couldn't fetch data: {fetch_message}")


def _render_visualization_tab():
    """Render the refresh control, the interactive graph, and a raw-data sample."""
    st.header("Knowledge Graph Visualization")
    col1, col2 = st.columns([3, 1])
    with col2:
        if st.button("🔄 Refresh Graph Data"):
            if st.session_state.driver:
                _show_status(*fetch_graph_data())
            else:
                st.error("Please connect to Neo4j first")
    with col1:
        st.markdown("### Interactive Graph")

    # Visualize the graph
    visualize_graph()

    # Display raw data
    if st.session_state.graph_data['nodes']:
        with st.expander("View Raw Graph Data"):
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Nodes")
                st.json(st.session_state.graph_data['nodes'][:5])  # Show first 5
            with col2:
                st.subheader("Relationships")
                st.json(st.session_state.graph_data['edges'][:5])  # Show first 5


def _run_quick_query(query):
    """Show and run a predefined query; return its records, or None after reporting an error."""
    st.code(query, language="cypher")
    if not st.session_state.driver:
        st.error("Please connect to Neo4j first")
        return None
    success, result = query_graph(query)
    if not success:
        st.error(f"Query failed: {result}")
        return None
    return result


def _render_query_tab():
    """Render the predefined quick queries and the custom Cypher query box."""
    st.header("Query Knowledge Graph")

    # Predefined queries
    st.subheader("Quick Queries")
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("All Nodes"):
            result = _run_quick_query("MATCH (n) RETURN n LIMIT 25")
            if result is not None:
                st.write(f"Found {len(result)} nodes")
                for record in result[:10]:  # Show first 10
                    st.write(record['n'])
    with col2:
        if st.button("All Relationships"):
            result = _run_quick_query(
                "MATCH (a)-[r]->(b) RETURN a.name, type(r), b.name LIMIT 25")
            if result is not None:
                st.dataframe(pd.DataFrame([dict(record) for record in result]))
    with col3:
        if st.button("Node Statistics"):
            result = _run_quick_query(
                "MATCH (n) RETURN labels(n)[0] as type, count(*) as count ORDER BY count DESC")
            if result is not None:
                if result:
                    df = pd.DataFrame([dict(record) for record in result])
                    st.bar_chart(df.set_index('type'))
                else:
                    # An empty result has no 'type' column to index by.
                    st.info("No nodes found")

    # Custom query
    st.subheader("Custom Cypher Query")
    custom_query = st.text_area("Enter your Cypher query:",
                                value="MATCH (n) RETURN n LIMIT 10",
                                height=100)
    if not st.button("Execute Query"):
        return
    if not (st.session_state.driver and custom_query.strip()):
        st.warning("Please connect to Neo4j and enter a query")
        return

    success, result = query_graph(custom_query)
    if not success:
        st.error(f"Query failed: {result}")
        return

    st.success(f"Query executed successfully! Found {len(result)} results")
    if not result:
        return
    # Try to convert to DataFrame for better display
    try:
        df = pd.DataFrame([dict(record) for record in result])
        st.dataframe(df)
    except Exception:
        # Deliberately broad: whatever prevents a tabular rendering, fall
        # back to writing the first records one by one.
        display_limit = 20
        for i, record in enumerate(result[:display_limit], start=1):
            st.write(f"**Result {i}:**")
            st.write(dict(record))
        if len(result) > display_limit:
            st.write(f"... and {len(result) - display_limit} more results")


def _render_analytics_tab():
    """Render summary metrics and type-distribution charts for the cached graph data."""
    st.header("Graph Analytics")
    nodes = st.session_state.graph_data['nodes']
    edges = st.session_state.graph_data['edges']
    if not nodes:
        st.info("No graph data available. Please build a knowledge graph first.")
        return

    # Basic statistics
    node_labels = [n['labels'][0] if n['labels'] else 'Unknown' for n in nodes]
    rel_types = [e['relationship'] for e in edges]

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Nodes", len(nodes))
    with col2:
        st.metric("Total Edges", len(edges))
    with col3:
        st.metric("Unique Node Types", len(set(node_labels)))
    with col4:
        st.metric("Unique Relationships", len(set(rel_types)))

    # Node type distribution
    st.subheader("Node Type Distribution")
    node_type_counts = pd.Series(node_labels).value_counts()
    fig = px.pie(values=node_type_counts.values, names=node_type_counts.index)
    st.plotly_chart(fig)

    # Relationship type distribution
    st.subheader("Relationship Type Distribution")
    if rel_types:
        rel_type_counts = pd.Series(rel_types).value_counts()
        fig = px.bar(x=rel_type_counts.index, y=rel_type_counts.values)
        fig.update_layout(xaxis_title="Relationship Type", yaxis_title="Count")
        st.plotly_chart(fig)
    else:
        st.info("No relationships to chart yet.")


def main():
    """Render the app: title, configuration sidebar, and the four content tabs."""
    st.title("🕸️ Neo4j Knowledge Graph Builder")
    st.markdown("Build, visualize, and query knowledge graphs from various file types")

    # Sidebar for configuration
    _render_sidebar()

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs(
        ["📁 Data Input", "🕸️ Graph Visualization", "🔍 Query Graph", "📊 Analytics"])
    with tab1:
        _render_data_input_tab()
    with tab2:
        _render_visualization_tab()
    with tab3:
        _render_query_tab()
    with tab4:
        _render_analytics_tab()
# Script entry point: render the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()