Embedding-Playground / heatmap_event.js
ping98k
Refactor heatmap event handling to update x-axis labels for clarity and ensure heatmap is plotted after processing search group logic.
048e22e
// Handles the heatmap event, group similarity logic, and text reordering for cluster visualization
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
import { plotHeatmap } from './plotting.js';
// Instruction string handed as the second argument to getGroupEmbeddings / getLineEmbeddings
// in this file. NOTE(review): presumably used as the embedding model's task prompt — confirm in embedding.js.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";
/**
 * Cosine similarity of two numeric vectors.
 *
 * Iterates over the length of `a`; `b` is read at the same indices.
 *
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} dot(a, b) / sqrt(|a|^2 * |b|^2), or 0 when either squared
 *   norm is falsy (e.g. an all-zero or empty vector).
 */
function cosine(a, b) {
  let dotProduct = 0;
  let sqNormA = 0;
  let sqNormB = 0;

  for (let idx = 0, len = a.length; idx < len; idx += 1) {
    const x = a[idx];
    const y = b[idx];
    dotProduct += x * y;
    sqNormA += x * x;
    sqNormB += y * y;
  }

  // Guard against division by zero for degenerate vectors.
  if (!(sqNormA && sqNormB)) {
    return 0;
  }
  return dotProduct / Math.sqrt(sqNormA * sqNormB);
}
/**
 * Split each group into its content lines, dropping blank lines and `##` headers.
 *
 * A line counts as blank or as a header based on its trimmed form, so
 * whitespace-only lines and indented `##` headers are excluded too (the same
 * trim-then-test rule the cluster-name parsing in this file uses). Kept lines
 * are returned with their original text, untrimmed.
 *
 * @param {string[]} groups - Raw group texts, one string per group.
 * @returns {string[][]} For each group, its content lines in original order.
 */
function getCleanGroups(groups) {
  return groups.map(group =>
    group.split("\n").filter(line => {
      const trimmed = line.trim();
      return trimmed !== "" && !trimmed.startsWith("##");
    })
  );
}
/**
 * Concatenate the lines of every group into one list and embed them in a single call.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @param {string} task - Passed through unchanged to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: *}>} The flattened lines and
 *   whatever getLineEmbeddings resolves to for them.
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
  // One-level flatten: group order first, then line order within each group.
  const flattened = cleanGroups.flatMap(lines => lines);
  const embeddings = await getLineEmbeddings(flattened, task);
  return {
    allLines: flattened,
    allEmbeds: embeddings,
  };
}
/**
 * For each group, list the positions its lines occupy in the flattened
 * (all-groups) line array.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @returns {number[][]} One array per group holding consecutive global indices;
 *   numbering continues from where the previous group ended.
 */
function getIdxByGroup(cleanGroups) {
  const result = [];
  let nextGlobal = 0;

  for (const lines of cleanGroups) {
    const indices = [];
    for (let k = 0; k < lines.length; k += 1) {
      indices.push(nextGlobal);
      nextGlobal += 1;
    }
    result.push(indices);
  }

  return result;
}
/**
 * Render reordered groups back into text: a `## header` line followed by the
 * group's lines, with groups separated by two blank lines.
 *
 * @param {number[]} order - Original group indices, in output order.
 * @param {string[][]} sortedLines - Lines per output position (parallel to `order`).
 * @param {string[]} [clusterNames] - Header names indexed by original group index.
 * @param {number} n - Expected number of groups; names are used only when
 *   `clusterNames.length === n`, otherwise headers fall back to
 *   `Group <output position + 1>`.
 * @returns {string} The assembled text.
 */
function buildFinalText(order, sortedLines, clusterNames, n) {
  const namesUsable = clusterNames?.length === n;
  const sections = [];

  order.forEach((gIdx, position) => {
    let title;
    if (namesUsable) {
      title = clusterNames[gIdx];
    } else {
      title = `Group ${position + 1}`;
    }
    const body = sortedLines[position].join("\n");
    sections.push(`## ${title}\n${body}`);
  });

  return sections.join("\n\n\n");
}
/**
 * Handle a heatmap request from the UI.
 *
 * Steps:
 *  1. Read the main input and the optional search-group input from the DOM.
 *  2. Embed every group (groups are separated by two or more blank lines) and
 *     build the group-to-group cosine similarity matrix.
 *  3. If search text is present, embed its lines, average them into one
 *     reference vector, and order the groups by descending similarity to it.
 *     The heatmap then gets an extra leading "Search" row/column.
 *  4. If search text is present and the sort mode is "line" or "group",
 *     re-sort the lines inside each group and write the reordered text back
 *     into the main input.
 *  5. Plot the heatmap and set the progress bar to 100%.
 *
 * Reads/writes DOM elements with ids: progress-bar, progress-bar-inner, input,
 * search-group-input (optional), search-sort-mode (optional, defaults to "group").
 *
 * @returns {Promise<void>}
 */
export async function handleHeatmapEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";

  const inputEl = document.getElementById("input");
  const text = inputEl.value;
  // The search group comes from its own input; "##search" in the main input is not special.
  const searchGroupText = document.getElementById("search-group-input")?.value.trim();
  // Sort mode comes from the dropdown: 'line' or 'group'.
  const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";
  const sortByLine = searchSortMode === "line";
  const sortByGroup = searchSortMode === "group";

  // Every trimmed line starting with "##" in the main input is taken as a cluster name.
  const clusterNames = text.split(/\n/)
    .map(x => x.trim())
    .filter(x => x && x.startsWith('##'))
    .map(x => x.replace(/^##\s*/, ''));
  const labelFor = i => clusterNames[i] || `Group ${i + 1}`;

  const groups = text.split(/\n{3,}/);
  const groupEmbeddings = await getGroupEmbeddings(groups, task);
  const n = groupEmbeddings.length;
  progressBarInner.style.width = "30%";

  // Pairwise group similarity. Using the shared helper keeps the zero-norm
  // guard (0 instead of NaN) consistent with every other similarity in this file.
  const sim = groupEmbeddings.map(a => groupEmbeddings.map(b => cosine(a, b)));
  progressBarInner.style.width = "60%";

  // Default: groups in their original order.
  let groupOrder = Array.from({ length: n }, (_, i) => i);

  const searchLines = searchGroupText
    ? searchGroupText.split(/\n/).map(l => l.trim()).filter(l => l)
    : [];
  const hasSearch = searchLines.length > 0;
  let searchEmbeds = [];
  let refEmbed = null;
  let simToSearch = [];
  if (hasSearch) {
    searchEmbeds = await getLineEmbeddings(searchLines, task);
    // Reference vector = component-wise mean of the search line embeddings.
    refEmbed = searchEmbeds[0].map(
      (_, i) => searchEmbeds.reduce((sum, e) => sum + e[i], 0) / searchEmbeds.length
    );
    // Most similar group first; these scores are reused for the heatmap below.
    simToSearch = groupEmbeddings
      .map((emb, idx) => ({ idx, sim: cosine(refEmbed, emb) }))
      .sort((a, b) => b.sim - a.sim);
    groupOrder = simToSearch.map(x => x.idx);
  }

  // Build the (possibly reordered) matrix and labels for the heatmap.
  let simOrdered;
  let xLabels;
  if (hasSearch) {
    // The search group is not part of groupEmbeddings, so it is added as an
    // extra first row/column with similarity 1 to itself.
    simOrdered = [
      [1, ...simToSearch.map(x => x.sim)],
      ...simToSearch.map(({ idx: i, sim: s }) => [s, ...groupOrder.map(j => sim[i][j])]),
    ];
    xLabels = ["Search", ...groupOrder.map(i => labelFor(i))];
  } else {
    simOrdered = groupOrder.map(i => groupOrder.map(j => sim[i][j]));
    xLabels = groupOrder.map(i => labelFor(i));
  }

  // Re-sort the lines within each group and rewrite the main input.
  // "line" mode scores a line by its best match among the search lines;
  // "group" mode scores it against the mean search embedding.
  if (hasSearch && (sortByLine || sortByGroup)) {
    const cleanGroups = getCleanGroups(groups);
    const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
    const idxByGroup = getIdxByGroup(cleanGroups);
    const score = sortByLine
      ? e => Math.max(...searchEmbeds.map(se => cosine(se, e)))
      : e => cosine(refEmbed, e);
    const sortedLines = groupOrder.map(gIdx =>
      idxByGroup[gIdx]
        .map(i => ({ t: allLines[i], s: score(allEmbeds[i]) }))
        .sort((a, b) => b.s - a.s)
        .map(o => o.t)
    );
    inputEl.value = buildFinalText(groupOrder, sortedLines, clusterNames, n);
  }

  // Plot only after the search/reorder logic has finished.
  plotHeatmap(simOrdered, xLabels, xLabels);
  progressBarInner.style.width = "100%";
}