Spaces:
Running
Running
ping98k
Refactor heatmap event handling to update x-axis labels for clarity and ensure heatmap is plotted after processing search group logic.
048e22e
// Handles the heatmap event, group similarity logic, and text reordering for cluster visualization
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js'; | |
import { plotHeatmap } from './plotting.js'; | |
// Instruction string passed as the `task` argument to every embedding call in this module.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";
// Cosine similarity between two vectors
/**
 * Cosine similarity between two numeric vectors.
 *
 * The loop runs over the indices of `a`; `b` is read at the same indices.
 *
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} dot(a, b) / sqrt(|a|^2 * |b|^2), or 0 when either
 *   squared norm is falsy (zero or NaN).
 */
function cosine(a, b) {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  // A zero norm would make the quotient NaN; report "no similarity" instead.
  return normA && normB ? dot / Math.sqrt(normA * normB) : 0;
}
// Remove group headers and split each group into an array of lines (excluding empty lines and headers)
/**
 * Split each group into its content lines, dropping headers and blank lines.
 *
 * A line is a header when, after trimming, it starts with "##" - the same
 * test handleHeatmapEvent applies when it parses cluster names, so a line
 * counted as a header there is never kept as content here. Whitespace-only
 * lines (including the lone "\r" left by CRLF input) count as blank.
 * Kept lines are returned unmodified (not trimmed).
 *
 * @param {string[]} groups - Raw text of each group.
 * @returns {string[][]} For each group, its content lines in original order.
 */
function getCleanGroups(groups) {
  return groups.map((group) =>
    group.split(/\r?\n/).filter((line) => {
      const trimmed = line.trim();
      return trimmed !== "" && !trimmed.startsWith("##");
    })
  );
}
// Flatten all lines from all groups and get their embeddings
/**
 * Flatten every group's lines into one list and embed them in a single call.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @param {string} task - Task string forwarded to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: Array}>} The flattened
 *   lines and whatever getLineEmbeddings returned for them (presumably one
 *   embedding per line, index-aligned with `allLines` - verify against
 *   embedding.js).
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
  const allLines = cleanGroups.flat();
  // Nothing to embed: skip the embedding call entirely.
  if (allLines.length === 0) {
    return { allLines, allEmbeds: [] };
  }
  const allEmbeds = await getLineEmbeddings(allLines, task);
  return { allLines, allEmbeds };
}
// Build an index mapping for each group to map group-relative indices to global indices
/**
 * For each group, list the positions its lines occupy in the flattened
 * (`cleanGroups.flat()`) line array.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @returns {number[][]} `result[g][k]` is the flattened index of line `k`
 *   of group `g`; an empty group maps to an empty array.
 */
function getIdxByGroup(cleanGroups) {
  const idxByGroup = [];
  let offset = 0; // number of lines contributed by the groups seen so far
  for (const group of cleanGroups) {
    idxByGroup.push(Array.from({ length: group.length }, (_, i) => offset + i));
    offset += group.length;
  }
  return idxByGroup;
}
// Build the final output text for reordered groups, including headers and sorted lines
/**
 * Render reordered groups back into the "## header" + lines text format.
 *
 * @param {number[]} order - Original group index for each output position.
 * @param {string[][]} sortedLines - Lines for each output position
 *   (index-aligned with `order`, not with the original group indices).
 * @param {string[]|null|undefined} clusterNames - Header names indexed by
 *   original group index.
 * @param {number} n - Expected number of groups.
 * @returns {string} Groups joined by two blank lines. A group's header is
 *   `clusterNames[gIdx]` only when `clusterNames` has exactly `n` entries;
 *   otherwise it is "Group k", where k is the 1-based OUTPUT position.
 */
function buildFinalText(order, sortedLines, clusterNames, n) {
  const hasAllNames = clusterNames?.length === n;
  return order
    .map((gIdx, position) => {
      const header = hasAllNames ? clusterNames[gIdx] : `Group ${position + 1}`;
      return `## ${header}\n${sortedLines[position].join("\n")}`;
    })
    .join("\n\n\n");
}
/**
 * Handle a heatmap request from the UI.
 *
 * Reads the grouped text from `#input`, embeds each group, and plots the
 * group-to-group cosine-similarity heatmap. When `#search-group-input`
 * holds text, the groups are ranked by similarity to the mean embedding of
 * the search lines, a "Search" row/column is added to the heatmap, and the
 * `#input` text is rewritten with groups in ranked order and the lines of
 * each group sorted according to `#search-sort-mode`:
 *   - "line":  by the best cosine match against any single search line;
 *   - "group": by cosine against the mean search embedding (default).
 *
 * Progress is reported through `#progress-bar-inner` (0% -> 30% -> 60% -> 100%).
 *
 * @returns {Promise<void>}
 */
export async function handleHeatmapEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";

  const inputEl = document.getElementById("input");
  const text = inputEl.value;
  // The search group comes from its own input; "##search" in the main input is not used.
  const searchGroupText = document.getElementById("search-group-input")?.value.trim();
  // Sort mode from the dropdown; any value other than "line"/"group" disables line reordering.
  const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";

  // Cluster names are the "##" header lines of the main input, in order of appearance.
  const clusterNames = text
    .split(/\n/)
    .map((line) => line.trim())
    .filter((line) => line.startsWith("##"))
    .map((line) => line.replace(/^##\s*/, ""));
  // Groups are separated by two or more blank lines.
  const groups = text.split(/\n{3,}/);

  // NOTE(review): the code below indexes `groups` and `groupEmbeddings` with
  // the same indices, so getGroupEmbeddings presumably returns one embedding
  // per group in input order - confirm against embedding.js.
  const groupEmbeddings = await getGroupEmbeddings(groups, task);
  const n = groupEmbeddings.length;
  progressBarInner.style.width = "30%";

  // Pairwise group similarity; the shared helper yields 0 instead of NaN for zero-norm embeddings.
  const sim = groupEmbeddings.map((a) => groupEmbeddings.map((b) => cosine(a, b)));
  progressBarInner.style.width = "60%";

  // Heatmap label: the parsed header at that group index, else a 1-based "Group k".
  const labelFor = (gIdx) => clusterNames[gIdx] || `Group ${gIdx + 1}`;

  const searchLines = searchGroupText
    ? searchGroupText.split(/\n/).map((line) => line.trim()).filter((line) => line)
    : [];
  const searchEmbeds =
    searchLines.length > 0 ? await getLineEmbeddings(searchLines, task) : [];
  // Search mode needs at least one embedded search line to build a reference vector.
  const hasSearch = searchEmbeds.length > 0;

  // Group indices in display order; stays in input order unless a search ranks them.
  let groupOrder = Array.from({ length: n }, (_, i) => i);
  let refEmbed = null;
  let simOrdered;
  let xLabels;

  if (hasSearch) {
    // Reference vector: component-wise mean of the search-line embeddings.
    refEmbed = searchEmbeds[0].map(
      (_, k) => searchEmbeds.reduce((sum, e) => sum + e[k], 0) / searchEmbeds.length
    );
    // Rank groups by similarity to the reference, most similar first.
    const ranked = groupEmbeddings
      .map((emb, idx) => ({ idx, sim: cosine(refEmbed, emb) }))
      .sort((a, b) => b.sim - a.sim);
    groupOrder = ranked.map((r) => r.idx);
    const simToSearch = ranked.map((r) => r.sim);

    // The search group is row/column 0 of the heatmap, with similarity 1 to itself.
    simOrdered = [
      [1, ...simToSearch],
      ...groupOrder.map((i, pos) => [
        simToSearch[pos],
        ...groupOrder.map((j) => sim[i][j]),
      ]),
    ];
    xLabels = ["Search", ...groupOrder.map((gIdx) => labelFor(gIdx))];
  } else {
    simOrdered = groupOrder.map((i) => groupOrder.map((j) => sim[i][j]));
    xLabels = groupOrder.map((gIdx) => labelFor(gIdx));
  }

  // With a search group, rewrite the input: groups in ranked order, lines sorted by score.
  const sortsLines = searchSortMode === "line" || searchSortMode === "group";
  if (hasSearch && sortsLines) {
    const scoreLine =
      searchSortMode === "line"
        ? (emb) => Math.max(...searchEmbeds.map((se) => cosine(se, emb)))
        : (emb) => cosine(refEmbed, emb);

    const cleanGroups = getCleanGroups(groups);
    const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
    const idxByGroup = getIdxByGroup(cleanGroups);

    const sortedLines = groupOrder.map((gIdx) =>
      idxByGroup[gIdx]
        .map((i) => ({ t: allLines[i], s: scoreLine(allEmbeds[i]) }))
        .sort((a, b) => b.s - a.s)
        .map((entry) => entry.t)
    );
    inputEl.value = buildFinalText(groupOrder, sortedLines, clusterNames, n);
  }

  // Plot last, after all search-group processing, with the same labels on both axes.
  plotHeatmap(simOrdered, xLabels, xLabels);
  progressBarInner.style.width = "100%";
}