import asyncio
import contextlib
import io
import json
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from neo4j import GraphDatabase

# Configure page
st.set_page_config(
    page_title="Neo4j Knowledge Graph Builder",
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling.
# NOTE(review): the CSS payload is empty in this copy of the file (the <style>
# block looks like it was stripped); restore it or remove this call.
st.markdown(""" """, unsafe_allow_html=True)

# Initialize session state. Streamlit re-executes this script on every
# interaction, so only fill in keys that are not present yet.
if "driver" not in st.session_state:
    st.session_state.driver = None
if "kg_builder" not in st.session_state:
    st.session_state.kg_builder = None
if "graph_data" not in st.session_state:
    st.session_state.graph_data = {"nodes": [], "edges": []}

# Marker colours for the node types of the default schema; any other type is
# drawn with DEFAULT_NODE_COLOR.
NODE_COLOR_MAP = {
    "Person": "#FF6B6B",
    "House": "#4ECDC4",
    "Planet": "#45B7D1",
    "Organization": "#96CEB4",
    "Location": "#FFEAA7",
}
DEFAULT_NODE_COLOR = "#DDA0DD"

# How many raw records the custom-query fallback view prints.
MAX_DISPLAY_RESULTS = 20
# Property values longer than this are truncated in node hover text
# (e.g. embedding vectors would otherwise flood the tooltip).
MAX_HOVER_VALUE_CHARS = 120

# MIME types consulted when an upload's file extension is not recognised.
_MIME_KINDS = {
    "text/plain": "txt",
    "text/csv": "csv",
    "application/json": "json",
}


def init_neo4j_connection(uri: str, username: str, password: str):
    """Open and verify a Neo4j driver, then store it in session state.

    Returns ``(success, message)``. On success any previously stored driver is
    closed and the cached KG pipeline is discarded, because that pipeline was
    constructed around the old driver.
    """
    driver = None
    try:
        driver = GraphDatabase.driver(uri, auth=(username, password))
        driver.verify_connectivity()
    except Exception as e:
        # Don't leak the freshly created driver when verification fails.
        if driver is not None:
            with contextlib.suppress(Exception):
                driver.close()
        return False, f"Connection failed: {str(e)}"

    previous_driver = st.session_state.driver
    had_pipeline = st.session_state.kg_builder is not None
    st.session_state.driver = driver
    st.session_state.kg_builder = None
    if previous_driver is not None:
        # Best-effort cleanup of a driver nothing references any more.
        with contextlib.suppress(Exception):
            previous_driver.close()

    message = "Connected successfully!"
    if had_pipeline:
        message += " Please re-initialize the KG Pipeline for this connection."
    return True, message


def init_kg_pipeline(openai_api_key: str, node_types: List[str],
                     relationship_types: List[str], patterns: List[tuple]):
    """Build a ``SimpleKGPipeline`` on the current driver and cache it.

    Args:
        openai_api_key: Key used for both the embedder and the LLM.
        node_types: Allowed node labels for extraction.
        relationship_types: Allowed relationship types for extraction.
        patterns: ``(source, relationship, target)`` triples allowed in the schema.

    Returns:
        ``(success, message)``; on success ``st.session_state.kg_builder`` is set.
    """
    try:
        # Import here to avoid issues if packages aren't installed
        from neo4j_graphrag.embeddings import OpenAIEmbeddings
        from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
        from neo4j_graphrag.llm import OpenAILLM

        embedder = OpenAIEmbeddings(
            model="text-embedding-3-large",
            api_key=openai_api_key,
        )
        llm = OpenAILLM(
            model_name="gpt-4o",
            model_params={
                "max_tokens": 2000,
                "response_format": {"type": "json_object"},
                "temperature": 0,
            },
            api_key=openai_api_key,
        )
        kg_builder = SimpleKGPipeline(
            llm=llm,
            driver=st.session_state.driver,
            embedder=embedder,
            schema={
                "node_types": node_types,
                "relationship_types": relationship_types,
                "patterns": patterns,
            },
            on_error="IGNORE",
            from_pdf=False,
        )
        st.session_state.kg_builder = kg_builder
        return True, "Knowledge Graph Pipeline initialized successfully!"
    except Exception as e:
        return False, f"Pipeline initialization failed: {str(e)}"


def process_csv_data(df: pd.DataFrame, text_columns: List[str]) -> str:
    """Convert CSV data to text for knowledge graph processing.

    Each row becomes ``"col: value. col: value"`` over the requested columns
    (columns missing from ``df`` and NA values are skipped). Rows that produce
    nothing are dropped; the remaining rows are separated by blank lines.
    """
    columns = [col for col in text_columns if col in df.columns]
    text_data = []
    for _, row in df.iterrows():
        row_text = [f"{col}: {row[col]}" for col in columns if pd.notna(row[col])]
        if row_text:
            text_data.append(". ".join(row_text))
    return "\n\n".join(text_data)


def _uploaded_file_kind(uploaded_file) -> Optional[str]:
    """Classify an upload as ``"txt"``, ``"csv"`` or ``"json"`` (None if unsupported).

    The extension is checked first because browsers do not report MIME types
    consistently (a .csv may not arrive as ``text/csv``).
    """
    name = (uploaded_file.name or "").lower()
    for kind in ("txt", "csv", "json"):
        if name.endswith(f".{kind}"):
            return kind
    return _MIME_KINDS.get(uploaded_file.type)


def _read_csv_upload(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded CSV independently of the upload's read position.

    ``getvalue()`` returns the whole buffer, so the same upload can be parsed
    for the preview and again for graph building within one script run.
    """
    return pd.read_csv(io.BytesIO(uploaded_file.getvalue()))


def read_uploaded_file(uploaded_file) -> str:
    """Read content from an uploaded file as text.

    Returns ``""`` (after showing a Streamlit message) when the file type is
    unsupported or the file cannot be read, so no error text is ever passed
    to the LLM as if it were document content.
    """
    kind = _uploaded_file_kind(uploaded_file)
    try:
        if kind == "txt":
            return str(uploaded_file.getvalue(), "utf-8")
        if kind == "csv":
            df = _read_csv_upload(uploaded_file)
            # Convert all columns to text representation
            return process_csv_data(df, list(df.columns))
        if kind == "json":
            data = json.loads(uploaded_file.getvalue())
            return json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
        return ""
    st.warning(
        f"Skipping {uploaded_file.name}: file type {uploaded_file.type} "
        "is not supported for text extraction."
    )
    return ""


async def build_knowledge_graph(text: str):
    """Run the cached KG pipeline over ``text``; returns ``(success, message)``."""
    try:
        await st.session_state.kg_builder.run_async(text=text)
    except Exception as e:
        return False, f"Failed to build knowledge graph: {str(e)}"
    return True, "Knowledge graph built successfully!"


def fetch_graph_data(limit: int = 1000):
    """Fetch up to ``limit`` nodes and ``limit`` relationships into session state.

    Nodes and relationships are sampled by independent queries, so an edge may
    reference a node outside the node sample. Returns ``(success, message)``.
    """
    if st.session_state.driver is None:
        return False, "Not connected to Neo4j"
    # NOTE(review): id() is deprecated in Neo4j 5 in favour of elementId();
    # it is kept here because the app only needs a per-node key.
    nodes_query = """
    MATCH (n)
    RETURN id(n) AS id, labels(n) AS labels, properties(n) AS properties
    LIMIT $limit
    """
    edges_query = """
    MATCH (a)-[r]->(b)
    RETURN id(a) AS source, id(b) AS target, type(r) AS relationship,
           properties(r) AS properties
    LIMIT $limit
    """
    try:
        with st.session_state.driver.session() as session:
            nodes = [
                {
                    "id": record["id"],
                    "labels": record["labels"],
                    "properties": record["properties"],
                }
                for record in session.run(nodes_query, parameters={"limit": limit})
            ]
            edges = [
                {
                    "source": record["source"],
                    "target": record["target"],
                    "relationship": record["relationship"],
                    "properties": record["properties"],
                }
                for record in session.run(edges_query, parameters={"limit": limit})
            ]
    except Exception as e:
        return False, f"Failed to fetch graph data: {str(e)}"

    st.session_state.graph_data = {"nodes": nodes, "edges": edges}
    return True, f"Fetched {len(nodes)} nodes and {len(edges)} relationships"


def _primary_label(labels: List[str]) -> str:
    """Pick the label that best names a node's type.

    Labels starting with a double underscore (internal bookkeeping labels) are
    skipped when a more specific label exists; unlabeled nodes are "Unknown".
    """
    if not labels:
        return "Unknown"
    for label in labels:
        if not label.startswith("__"):
            return label
    return labels[0]


def _build_networkx_graph(graph_data: Dict[str, List[Dict[str, Any]]]) -> nx.Graph:
    """Convert fetched Neo4j rows into an undirected NetworkX graph.

    Node attributes are the Neo4j properties plus ``label`` and ``name``; edge
    attributes are the relationship properties plus ``relationship``. They are
    passed as attribute dicts (not ``**kwargs``) so a property that is itself
    called ``name``/``label``/``relationship`` cannot raise "multiple values
    for keyword argument".
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (
            node["id"],
            {
                **node["properties"],
                "label": _primary_label(node["labels"]),
                "name": node["properties"].get("name", f"Node_{node['id']}"),
            },
        )
        for node in graph_data["nodes"]
    )
    graph.add_edges_from(
        (
            edge["source"],
            edge["target"],
            {**edge["properties"], "relationship": edge["relationship"]},
        )
        for edge in graph_data["edges"]
    )
    return graph


def _layout_positions(graph: nx.Graph) -> Dict[Any, Any]:
    """Compute 2-D node positions, degrading gracefully if a layout fails."""
    try:
        # Fixed seed keeps the picture stable across Streamlit reruns.
        return nx.spring_layout(graph, k=1, iterations=50, seed=42)
    except Exception:
        # spring_layout may need optional numeric backends for larger graphs.
        pass
    try:
        return nx.circular_layout(graph)
    except Exception:
        return {node: (0, 0) for node in graph.nodes()}


def _node_hover_text(node_id: Any, label: str, name: str, attributes: Dict[str, Any]) -> str:
    """Build the hover text for one node; Plotly renders ``<br>`` as a line break."""
    parts = [f"ID: {node_id}", f"Type: {label}", f"Name: {name}"]
    for key, value in attributes.items():
        # Skip fields already shown and empty values (0/False are still shown).
        if key in ("label", "name") or value in (None, "", []):
            continue
        text = str(value)
        if len(text) > MAX_HOVER_VALUE_CHARS:
            text = text[: MAX_HOVER_VALUE_CHARS - 3] + "..."
        parts.append(f"{key}: {text}")
    return "<br>".join(parts)


def visualize_graph():
    """Render the cached graph as an interactive Plotly chart plus summary metrics.

    Shows a warning instead when no graph data has been fetched yet.
    """
    graph_data = st.session_state.graph_data
    if not graph_data["nodes"]:
        st.warning("No graph data available. Please build a knowledge graph first.")
        return

    G = _build_networkx_graph(graph_data)
    pos = _layout_positions(G)

    # Edges: one polyline, segments separated by None gaps.
    edge_x: List[Optional[float]] = []
    edge_y: List[Optional[float]] = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_x, node_y, node_info, node_colors, node_labels = [], [], [], [], []
    for node_id, data in G.nodes(data=True):
        x, y = pos[node_id]
        # Nodes only reached through an edge (outside the node sample) have no attributes.
        label = data.get("label", "Unknown")
        name = data.get("name", f"Node_{node_id}")
        node_x.append(x)
        node_y.append(y)
        node_info.append(_node_hover_text(node_id, label, name, data))
        node_colors.append(NODE_COLOR_MAP.get(label, DEFAULT_NODE_COLOR))
        node_labels.append(name)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode="markers+text",
        hoverinfo="text",
        text=node_labels,
        textposition="middle center",
        hovertext=node_info,
        marker=dict(
            size=30,
            color=node_colors,
            line=dict(width=2, color="white"),
        ),
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(text="Knowledge Graph Visualization", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="Hover over nodes for details",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    font=dict(color="#888888", size=12),
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            height=600,
        ),
    )
    st.plotly_chart(fig, use_container_width=True)

    # Display graph statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Nodes", len(G.nodes()))
    with col2:
        st.metric("Edges", len(G.edges()))
    with col3:
        st.metric("Node Types", len({data.get("label", "Unknown") for _, data in G.nodes(data=True)}))
    with col4:
        density = nx.density(G) if len(G.nodes()) > 1 else 0
        st.metric("Density", f"{density:.3f}")


def query_graph(query: str):
    """Execute a Cypher query; return ``(True, records)`` or ``(False, error)``.

    Records are materialised inside the session block so they remain usable
    after the session is closed.
    """
    try:
        with st.session_state.driver.session() as session:
            return True, list(session.run(query))
    except Exception as e:
        return False, str(e)


def _parse_lines(text: str) -> List[str]:
    """Return the stripped, non-blank lines of a text-area value."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def _parse_patterns(text: str) -> Tuple[List[Tuple[str, ...]], List[str]]:
    """Parse ``Source,Relationship,Target`` lines.

    Returns ``(patterns, rejected)``; ``rejected`` holds non-blank lines that do
    not have exactly three comma-separated parts.
    """
    patterns: List[Tuple[str, ...]] = []
    rejected: List[str] = []
    for line in _parse_lines(text):
        parts = tuple(part.strip() for part in line.split(","))
        if len(parts) == 3:
            patterns.append(parts)
        else:
            rejected.append(line)
    return patterns, rejected


def _render_sidebar() -> None:
    """Draw the Neo4j connection, OpenAI key and schema controls."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        # Neo4j Connection
        st.subheader("Neo4j Database")
        neo4j_uri = st.text_input("Neo4j URI", value="neo4j://localhost:7687")
        neo4j_username = st.text_input("Username", value="neo4j")
        neo4j_password = st.text_input("Password", type="password", value="password")
        if st.button("Connect to Neo4j"):
            success, message = init_neo4j_connection(neo4j_uri, neo4j_username, neo4j_password)
            if success:
                st.success(message)
            else:
                st.error(message)

        # OpenAI Configuration
        st.subheader("OpenAI Settings")
        openai_api_key = st.text_input("OpenAI API Key", type="password")

        # Schema Configuration
        st.subheader("Knowledge Graph Schema")
        node_types = _parse_lines(
            st.text_area(
                "Node Types (one per line)",
                value="Person\nHouse\nPlanet\nOrganization\nLocation",
            )
        )
        relationship_types = _parse_lines(
            st.text_area(
                "Relationship Types (one per line)",
                value="PARENT_OF\nHEIR_OF\nRULES\nWORKS_FOR\nLIVES_IN",
            )
        )

        # Patterns
        st.subheader("Relationship Patterns")
        patterns, rejected = _parse_patterns(
            st.text_area(
                "Patterns (format: Source,Relationship,Target - one per line)",
                value=(
                    "Person,PARENT_OF,Person\nPerson,HEIR_OF,House\n"
                    "House,RULES,Planet\nPerson,WORKS_FOR,Organization\n"
                    "Person,LIVES_IN,Location"
                ),
            )
        )
        if rejected:
            st.warning("Ignoring malformed pattern line(s): " + "; ".join(rejected))

        if st.button("Initialize KG Pipeline"):
            if not openai_api_key:
                st.error("Please enter an OpenAI API key first")
            elif not st.session_state.driver:
                st.error("Please connect to Neo4j first")
            else:
                success, message = init_kg_pipeline(
                    openai_api_key, node_types, relationship_types, patterns
                )
                if success:
                    st.success(message)
                else:
                    st.error(message)


def _collect_input_text(uploaded_files, text_input: str) -> str:
    """Concatenate extracted file text and pasted text, each under a source header.

    Files that yield no text (unsupported or unreadable) are skipped.
    """
    sections = []
    for uploaded_file in uploaded_files or []:
        file_content = read_uploaded_file(uploaded_file)
        if file_content.strip():
            sections.append(f"\n\n--- From {uploaded_file.name} ---\n{file_content}")
    if text_input.strip():
        sections.append(f"\n\n--- Direct Input ---\n{text_input.strip()}")
    return "".join(sections)


def _run_pipeline_and_refresh(all_text: str) -> None:
    """Run the KG pipeline on ``all_text`` and refresh the cached graph data."""
    try:
        # asyncio.run creates a fresh event loop and always closes it, even
        # when the coroutine raises.
        success, message = asyncio.run(build_knowledge_graph(all_text))
    except Exception as e:
        st.error(f"Error: {str(e)}")
        return

    if not success:
        st.error(message)
        return

    st.success(message)
    # Fetch updated graph data
    fetch_success, fetch_message = fetch_graph_data()
    if fetch_success:
        st.info(fetch_message)
    else:
        st.warning(f"Graph built but couldn't fetch data: {fetch_message}")


def _render_data_input_tab() -> None:
    """File upload / text input and the "Build Knowledge Graph" action."""
    st.header("Data Input & Knowledge Graph Creation")

    # PDF is not offered: there is no text extraction for it in this app.
    uploaded_files = st.file_uploader(
        "Upload files for knowledge graph creation",
        type=["txt", "csv", "json"],
        accept_multiple_files=True,
    )
    text_input = st.text_area("Or paste text directly:", height=200)

    # CSV specific options
    csv_files = [f for f in uploaded_files or [] if _uploaded_file_kind(f) == "csv"]
    if csv_files:
        st.subheader("CSV Processing Options")
        for csv_file in csv_files:
            try:
                df = _read_csv_upload(csv_file)
            except Exception as e:
                st.error(f"Could not preview {csv_file.name}: {str(e)}")
                continue
            st.write(f"**{csv_file.name}** - Shape: {df.shape}")
            st.write("Preview:")
            st.dataframe(df.head())

    if st.button("🚀 Build Knowledge Graph", type="primary"):
        if not st.session_state.kg_builder:
            st.error("Please initialize the KG Pipeline first")
            return
        all_text = _collect_input_text(uploaded_files, text_input)
        if not all_text.strip():
            st.warning("Please provide some text or upload files to process")
            return
        with st.spinner("Building knowledge graph..."):
            _run_pipeline_and_refresh(all_text)


def _render_visualization_tab() -> None:
    """Graph refresh button, interactive graph and a raw-data sample."""
    st.header("Knowledge Graph Visualization")

    col1, col2 = st.columns([3, 1])
    # The refresh column runs first so the chart below uses refreshed data.
    with col2:
        if st.button("🔄 Refresh Graph Data"):
            if st.session_state.driver:
                success, message = fetch_graph_data()
                if success:
                    st.success(message)
                else:
                    st.error(message)
            else:
                st.error("Please connect to Neo4j first")

    with col1:
        st.markdown("### Interactive Graph")
        visualize_graph()

    if st.session_state.graph_data["nodes"]:
        with st.expander("View Raw Graph Data"):
            nodes_col, edges_col = st.columns(2)
            with nodes_col:
                st.subheader("Nodes")
                st.json(st.session_state.graph_data["nodes"][:5])  # Show first 5
            with edges_col:
                st.subheader("Relationships")
                st.json(st.session_state.graph_data["edges"][:5])  # Show first 5


def _run_quick_query(query: str) -> Optional[list]:
    """Show ``query``, run it and return its records; None after reporting a problem."""
    st.code(query, language="cypher")
    if not st.session_state.driver:
        st.error("Please connect to Neo4j first")
        return None
    success, result = query_graph(query)
    if not success:
        st.error(f"Query failed: {result}")
        return None
    return result


def _show_records(records: list) -> None:
    """Render query records as a table, falling back to one dict per record."""
    try:
        st.dataframe(pd.DataFrame([dict(record) for record in records]))
        return
    except Exception:
        pass  # fall through to the raw display below
    for i, record in enumerate(records[:MAX_DISPLAY_RESULTS]):
        st.write(f"**Result {i+1}:**")
        st.write(dict(record))
    if len(records) > MAX_DISPLAY_RESULTS:
        st.write(f"... and {len(records) - MAX_DISPLAY_RESULTS} more results")


def _render_query_tab() -> None:
    """Predefined quick queries plus a free-form Cypher box.

    NOTE: the custom box executes arbitrary Cypher with the connected user's
    privileges — intended for trusted operators only.
    """
    st.header("Query Knowledge Graph")

    # Predefined queries
    st.subheader("Quick Queries")
    col1, col2, col3 = st.columns(3)

    with col1:
        if st.button("All Nodes"):
            records = _run_quick_query("MATCH (n) RETURN n LIMIT 25")
            if records is not None:
                st.write(f"Found {len(records)} nodes")
                for record in records[:10]:  # Show first 10
                    st.write(record["n"])

    with col2:
        if st.button("All Relationships"):
            records = _run_quick_query("MATCH (a)-[r]->(b) RETURN a.name, type(r), b.name LIMIT 25")
            if records is not None:
                st.dataframe(pd.DataFrame([dict(record) for record in records]))

    with col3:
        if st.button("Node Statistics"):
            records = _run_quick_query(
                "MATCH (n) RETURN labels(n)[0] as type, count(*) as count ORDER BY count DESC"
            )
            if records:
                df = pd.DataFrame([dict(record) for record in records])
                st.bar_chart(df.set_index("type"))
            elif records is not None:
                st.info("The graph has no nodes yet.")

    # Custom query
    st.subheader("Custom Cypher Query")
    custom_query = st.text_area("Enter your Cypher query:", value="MATCH (n) RETURN n LIMIT 10", height=100)

    if st.button("Execute Query"):
        if not (st.session_state.driver and custom_query.strip()):
            st.warning("Please connect to Neo4j and enter a query")
            return
        success, result = query_graph(custom_query)
        if not success:
            st.error(f"Query failed: {result}")
            return
        st.success(f"Query executed successfully! Found {len(result)} results")
        if result:
            _show_records(result)


def _render_analytics_tab() -> None:
    """Summary metrics and type distributions for the cached graph sample."""
    st.header("Graph Analytics")

    nodes = st.session_state.graph_data["nodes"]
    if not nodes:
        st.info("No graph data available. Please build a knowledge graph first.")
        return
    edges = st.session_state.graph_data["edges"]

    node_labels = [_primary_label(n["labels"]) for n in nodes]
    rel_types = [e["relationship"] for e in edges]

    # Basic statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Nodes", len(nodes))
    with col2:
        st.metric("Total Edges", len(edges))
    with col3:
        st.metric("Unique Node Types", len(set(node_labels)))
    with col4:
        st.metric("Unique Relationships", len(set(rel_types)))

    # Node type distribution
    st.subheader("Node Type Distribution")
    node_type_counts = pd.Series(node_labels).value_counts()
    fig = px.pie(values=node_type_counts.values, names=node_type_counts.index)
    st.plotly_chart(fig)

    # Relationship type distribution
    st.subheader("Relationship Type Distribution")
    if not rel_types:
        st.info("No relationships to chart.")
        return
    rel_type_counts = pd.Series(rel_types).value_counts()
    fig = px.bar(x=rel_type_counts.index, y=rel_type_counts.values)
    fig.update_layout(xaxis_title="Relationship Type", yaxis_title="Count")
    st.plotly_chart(fig)


def main():
    """Streamlit entry point: sidebar configuration plus the four main tabs."""
    st.title("🕸️ Neo4j Knowledge Graph Builder")
    st.markdown("Build, visualize, and query knowledge graphs from various file types")

    _render_sidebar()

    tab1, tab2, tab3, tab4 = st.tabs(
        ["📄 Data Input", "🕸️ Graph Visualization", "🔍 Query Graph", "📊 Analytics"]
    )
    with tab1:
        _render_data_input_tab()
    with tab2:
        _render_visualization_tab()
    with tab3:
        _render_query_tab()
    with tab4:
        _render_analytics_tab()


if __name__ == "__main__":
    main()