Spaces:
Running
Running
import { | |
animation_duration, | |
chat, | |
cleanUpMessage, | |
event_types, | |
eventSource, | |
Generate, | |
getGeneratingApi, | |
is_send_press, | |
isStreamingEnabled, | |
substituteParamsExtended, | |
} from '../script.js'; | |
import { debounce, delay, getStringHash } from './utils.js'; | |
import { decodeTextTokens, getTokenizerBestMatch } from './tokenizers.js'; | |
import { power_user } from './power-user.js'; | |
import { callGenericPopup, POPUP_TYPE } from './popup.js'; | |
import { t } from './i18n.js'; | |
/** Number of tint classes (`logprobs_tint_0` … `logprobs_tint_3`) cycled across output tokens. */
const TINTS = 4;

/** Maximum number of messages whose logprob data is kept; older entries are evicted. */
const MAX_MESSAGE_LOGPROBS = 100;

/** Button that rerolls the whole `continue` prefix; only shown when a prefix exists. */
const REROLL_BUTTON = $('#logprobsReroll');

/**
 * A candidate token paired with the logarithm of its probability of being chosen.
 * @typedef {[string, number]} Candidate - (token, logprob)
 */

/**
 * @typedef {(Node|JQuery<Text>|JQuery<HTMLElement>)[]} NodeArray - DOM nodes (raw or jQuery-wrapped) to append
 */

/**
 * Logprob data captured for one generated message.
 * @typedef {Object} MessageLogprobData
 * @property {number} created - timestamp (ms) of when the data was saved
 * @property {number} hash - hash of the message's name, chat index and text (see getMessageHash)
 * @property {number} messageId - chat index of the source message
 * @property {number} swipeId - swipe ID of the source message at save time
 * @property {string} api - API that generated the message
 * @property {TokenLogprobs[]} messageLogprobs - logprob data for each generated
 * token, in output order
 * @property {string | null} continueFrom - the 'continue' prefix the message was
 * generated from, if any
 */

/**
 * Logprob data for one generated token.
 * @typedef {Object} TokenLogprobs
 * @property {string} token - the token the model produced
 * @property {Candidate[]} topLogprobs - the top candidate tokens at this position
 */

/**
 * Module state for the Token Probabilities panel.
 * @typedef {Object} LogprobsState
 * @property {?TokenLogprobs} selectedTokenLogprobs - logprobs of the token the user
 * currently has selected, or null when none is selected
 * @property {Map<number, MessageLogprobData>} messageLogprobs - saved logprob data,
 * keyed by message hash
 */

/** @type {LogprobsState} */
const state = {
    selectedTokenLogprobs: null,
    messageLogprobs: new Map(),
};
/**
 * Renders the Token Probabilities UI and all subviews with the active message's
 * logprobs data. If the message has no token logprobs, or Smooth Streaming is in
 * use, an explanatory empty-state message is displayed instead.
 *
 * Nothing is rendered while the output view is hidden.
 */
function renderAlternativeTokensView() {
    const view = $('#logprobs_generation_output');
    if (!view.is(':visible')) {
        return;
    }
    view.empty();
    state.selectedTokenLogprobs = null;
    renderTopLogprobs();

    const { messageLogprobs, continueFrom } = getActiveMessageLogprobData() || {};
    const usingSmoothStreaming = isStreamingEnabled() && power_user.smooth_streaming;
    if (!messageLogprobs?.length || usingSmoothStreaming) {
        // Also hide the reroll button and drop its handler. Otherwise it stays
        // visible, bound to the prefix length of a previously rendered message.
        REROLL_BUTTON.off('click').hide();
        const emptyState = $('<div></div>');
        emptyState.html(getLogprobsEmptyStateMessage(usingSmoothStreaming));
        emptyState.addClass('logprobs_empty_state');
        view.append(emptyState);
        return;
    }

    const prefix = continueFrom || '';
    REROLL_BUTTON.toggle(!!prefix);
    if (prefix) {
        REROLL_BUTTON.off('click').on('click', () => onPrefixClicked(prefix.length));
    }

    /** @type {NodeArray} */
    const tokenSpans = prefix ? buildPrefixSpans(prefix) : [];
    messageLogprobs.forEach((tokenData, i) => {
        const { token } = tokenData;
        const span = $('<span></span>');
        span.text(toVisibleWhitespace(token));
        span.addClass('logprobs_output_token');
        // consecutive tokens cycle through TINTS tint classes
        span.addClass('logprobs_tint_' + (i % TINTS));
        span.on('click', () => onSelectedTokenChanged(tokenData, span));
        addKeyboardProps(span);
        tokenSpans.push(...withVirtualWhitespace(token, span));
    });
    view.append(tokenSpans);

    // Scroll past long prior context so the first generated token is at the top.
    if (prefix) {
        const firstToken = view.find('.logprobs_output_token').first();
        const container = firstToken.parent();
        // offset() is document-relative, so the difference is measured from the
        // container's *current* scroll position. Adding scrollTop() turns it into
        // an absolute target, even if a previous scroll position was retained.
        const scrollOffset = container.scrollTop() + firstToken.offset().top - container.offset().top;
        container.scrollTop(scrollOffset);
    }
}

/**
 * Picks the empty-state message shown when no token probabilities can be displayed.
 * @param {boolean} usingSmoothStreaming - whether Smooth Streaming is active
 * @returns {string} HTML or plain text for the empty-state element
 */
function getLogprobsEmptyStateMessage(usingSmoothStreaming) {
    if (!power_user.request_token_probabilities) {
        return '<span>Enable <b>Request token probabilities</b> in the User Settings menu to use this feature.</span>';
    }
    if (usingSmoothStreaming) {
        return t`Token probabilities are not available when using Smooth Streaming.`;
    }
    if (is_send_press) {
        return t`Generation in progress...`;
    }
    return t`No token probabilities available for the current message.`;
}

/**
 * Builds clickable word spans for a 'continue' prefix. Clicking a word rerolls
 * from that word's character offset within the prefix. Newline-containing
 * delimiters become `<br>` elements; other delimiters become text nodes.
 * @param {string} prefix - the 'continue' prefix text
 * @returns {NodeArray} nodes for the prefix, in display order
 */
function buildPrefixSpans(prefix) {
    /** @type {NodeArray} */
    const nodes = [];
    const words = prefix.split(/\s+/);
    const delimiters = prefix.match(/\s+/g) || []; // the actual whitespace runs between words
    let cumulativeOffset = 0;
    words.forEach((word, i) => {
        const delimiter = delimiters[i];
        const offset = cumulativeOffset;
        const span = $('<span></span>');
        span.text(`${word} `);
        span.addClass('logprobs_output_prefix');
        span.attr('title', t`Reroll from this point`);
        span.on('click', () => onPrefixClicked(offset));
        addKeyboardProps(span);
        nodes.push(span);
        nodes.push(delimiter?.includes('\n')
            ? document.createElement('br')
            : document.createTextNode(delimiter || ' '),
        );
        cumulativeOffset += word.length + (delimiter?.length || 0);
    });
    return nodes;
}
/**
 * Makes an element operable from the keyboard like a button. It gets the
 * `button` role and joins the tab order, and pressing Enter or Space while it
 * has focus triggers its click handlers.
 *
 * The key's default action is suppressed. This stops Space from scrolling the
 * page, and stops native `<button>` elements (used for the candidate list) from
 * firing a second, native click on top of the one triggered here.
 *
 * @param {JQuery<HTMLElement>} element - element to enhance
 */
function addKeyboardProps(element) {
    element.attr('role', 'button');
    element.attr('tabindex', '0');
    element.on('keydown', (e) => {
        if (e.key !== 'Enter' && e.key !== ' ') {
            return;
        }
        e.preventDefault();
        element.trigger('click');
    });
}
/**
 * renderTopLogprobs renders the top logprobs subview with the currently
 * selected token highlighted. If no token is selected, the subview is left empty.
 *
 * Candidates are listed by descending logprob, followed by a disabled
 * `<others>` entry holding the remaining probability mass. That entry is tinted
 * red when the selected token is not among the listed candidates.
 *
 * Callers:
 * - renderAlternativeTokensView, to render the entire view
 * - onSelectedTokenChanged, to update the view when a token is selected
 */
function renderTopLogprobs() {
    $('#logprobs_top_logprobs_hint').hide();
    const view = $('.logprobs_candidate_list');
    view.empty();
    if (!state.selectedTokenLogprobs) {
        return;
    }
    const { token: selectedToken, topLogprobs } = state.selectedTokenLogprobs;

    let sum = 0;
    // Sort a copy: sorting in place would reorder the stored logprob data.
    const candidates = [...topLogprobs]
        .sort(([, logA], [, logB]) => logB - logA)
        .map(([text, log]) => {
            // Values <= 0 are log-probabilities and are converted. Positive values
            // are displayed as-is and not counted towards the <others> remainder.
            if (log <= 0) {
                const probability = Math.exp(log);
                sum += probability;
                return [text, probability, log];
            }
            return [text, log, null];
        });

    const nodes = [];
    let matched = false;
    for (const [token, probability, log] of candidates) {
        const tokenText = String(token);
        const container = createCandidateButton(tokenText, probability);
        const tokenNormalized = tokenText.replace(/^[▁Ġ]/g, ' ');
        if (token === selectedToken || tokenNormalized === selectedToken) {
            matched = true;
            container.addClass('selected');
        }
        if (log !== null) {
            container.attr('title', `logarithm: ${log}`);
        }
        container.on('click', () => onAlternativeClicked(state.selectedTokenLogprobs, tokenText));
        nodes.push(container);
    }

    // Remaining probability mass. It is clamped at 0 because floating-point
    // error can push the summed probabilities slightly above 1.
    const others = createCandidateButton('<others>', Math.max(0, 1 - sum));
    others.prop('disabled', true);
    // Highlight <others> when the selected token is not among the top candidates.
    if (!matched) {
        others.css('background-color', 'rgba(255, 0, 0, 0.1)');
    }
    nodes.push(others);
    view.append(nodes);
}

/**
 * Builds one entry of the top-logprobs candidate list. The entry is a button
 * showing the token with visible whitespace and its probability as a percentage.
 * @param {string} label - candidate token text
 * @param {number} probability - value to display as a percentage (fraction of 1)
 * @returns {JQuery<HTMLElement>} the candidate button, with keyboard support attached
 */
function createCandidateButton(label, probability) {
    const button = $('<button class="flex-container flexFlowColumn logprobs_top_candidate"></button>');
    const tokenText = $('<span></span>').text(toVisibleWhitespace(label));
    const percentText = $('<span></span>').text(`${(+probability * 100).toFixed(2)}%`);
    button.append(tokenText, percentText);
    addKeyboardProps(button);
    return button;
}
/**
 * Handles a click on a token in the output view. Clicking a token selects it
 * and shows its candidates. Clicking the already-selected token clears the
 * selection. Either way the top logprobs view is re-rendered.
 * @param {TokenLogprobs} logprobs - logprob data of the clicked token
 * @param {Node|JQuery} span - the span that was clicked
 */
function onSelectedTokenChanged(logprobs, span) {
    $('.logprobs_output_token.selected').removeClass('selected');
    const isDeselecting = state.selectedTokenLogprobs === logprobs;
    state.selectedTokenLogprobs = isDeselecting ? null : logprobs;
    if (!isDeselecting) {
        $(span).addClass('selected');
    }
    renderTopLogprobs();
}
/**
 * onAlternativeClicked is called when the user clicks on an alternative token
 * in the top logprobs view. It creates a new swipe message and prefills it with
 * all text up to the selected token, followed by the chosen alternative. It then
 * requests a `continue` completion from the model with the new prompt.
 *
 * Nothing happens if a generation is in progress, if the active message has no
 * logprob data, or if the selected token is no longer part of it (e.g. the chat
 * changed after the selection was made).
 *
 * @param {TokenLogprobs} tokenLogprobs - logprob data for selected alternative
 * @param {string} alternative - selected alternative token's text
 * @returns {*} the result of callGenericPopup when the API is unsupported; otherwise undefined
 */
function onAlternativeClicked(tokenLogprobs, alternative) {
    if (!checkGenerateReady()) {
        return;
    }
    if (getGeneratingApi() === 'openai') {
        const title = t`Feature unavailable`;
        const message = t`Due to API limitations, rerolling a token is not supported with OpenAI. Try switching to a different API.`;
        const content = `<h3>${title}</h3><p>${message}</p>`;
        return callGenericPopup(content, POPUP_TYPE.TEXT);
    }

    const logprobData = getActiveMessageLogprobData();
    if (!logprobData) {
        console.warn('No logprob data for the active message; cannot reroll from the selected token.');
        return;
    }
    const { messageLogprobs, continueFrom } = logprobData;
    const replaceIndex = messageLogprobs.findIndex(x => x === tokenLogprobs);
    if (replaceIndex === -1) {
        // Without this guard, slice(0, 0) would drop every token and the
        // reroll would silently restart from the bare prefix.
        console.warn('Selected token does not belong to the active message; ignoring alternative.');
        return;
    }

    const tokens = messageLogprobs.slice(0, replaceIndex + 1).map(({ token }) => token);
    // Convert tokenizer space/newline markers in the alternative back to real whitespace.
    tokens[replaceIndex] = String(alternative).replace(/^[▁Ġ]/g, ' ').replace(/Ċ/g, '\n');
    const prefix = continueFrom || '';
    addGeneration(prefix + tokens.join(''));
}
/**
 * Handles a click on the reroll button or on a word of the 'continue' prefix.
 * The active message's prefix is truncated at the given character offset, and
 * a generation is started from the result. When no offset is given, the whole
 * prefix is kept. When the active message has no prefix, an empty prompt is used.
 *
 * @param {number} [offset] - character offset in the prefix to reroll from
 * @returns {void}
 */
function onPrefixClicked(offset = undefined) {
    if (!checkGenerateReady()) {
        return;
    }
    const continueFrom = getActiveMessageLogprobData()?.continueFrom;
    const truncated = continueFrom ? continueFrom.substring(0, offset) : '';
    addGeneration(truncated);
}
/**
 * Reports whether a new generation may start. While one is already running,
 * shows a warning toast and returns false.
 * @returns {boolean} true when no generation is in progress
 */
function checkGenerateReady() {
    const busy = Boolean(is_send_press);
    if (busy) {
        toastr.warning('Please wait for the current generation to complete.');
    }
    return !busy;
}
/**
 * Starts a new swipe when the user picks an alternative token or rerolls from a
 * prefix.
 *
 * With a non-empty prompt, the prompt is seeded into a new swipe on the last
 * message, the UI swipes right onto it, and a `continue` generation is
 * requested. With an empty prompt, the UI just swipes right.
 *
 * @param {string} prompt - text the new swipe should continue from
 */
function addGeneration(prompt) {
    const hasPrompt = Boolean(prompt && prompt.length > 0);
    if (hasPrompt) {
        createSwipe(chat.length - 1, prompt);
    }
    $('.swipe_right:last').trigger('click');
    if (hasPrompt) {
        void Generate('continue');
    }
}
/**
 * Opens or closes the Token Probabilities panel (menu item or close button).
 *
 * Opening shows the panel as a flex container at zero opacity, renders its
 * content, then fades it in. Closing fades it out and hides it after
 * `animation_duration`. In both cases the `resizing` class is applied for the
 * duration of the fade and removed 50 ms after it completes.
 */
function onToggleLogprobsPanel() {
    const logprobsViewer = $('#logprobsViewer');
    const isOpening = logprobsViewer.css('display') === 'none';
    // Completion callback for the fade, shared by both directions (pattern borrowed from the CFGScale toggle).
    const endResizing = async function () {
        await delay(50);
        logprobsViewer.removeClass('resizing');
    };

    logprobsViewer.addClass('resizing');
    if (!isOpening) {
        logprobsViewer.transition({ opacity: 0.0, duration: animation_duration }, endResizing);
        setTimeout(function () {
            logprobsViewer.hide();
        }, animation_duration);
        return;
    }

    logprobsViewer.css('display', 'flex');
    logprobsViewer.css('opacity', 0.0);
    renderAlternativeTokensView();
    logprobsViewer.transition({ opacity: 1.0, duration: animation_duration }, endResizing);
}
/**
 * Appends a new swipe to the target chat message, seeded with the given text.
 * `swipe_id` is then pointed at the swipe just before the new one, so the
 * swipe-right click issued by addGeneration() lands on the new swipe.
 *
 * The prompt is first passed through cleanUpMessage. When reasoning
 * auto-parsing is on and the message already has parsed reasoning, the
 * (macro-substituted) reasoning prefix/suffix in the prompt decide where it goes:
 * - prefix and suffix both present: everything up to and including the first
 *   suffix is dropped, and the rest seeds the message text;
 * - prefix only: the first prefix occurrence is removed, the text seeds the new
 *   swipe's `extra.reasoning`, and the message text is left empty.
 * In every other case, the cleaned prompt seeds the message text as-is.
 *
 * @param {number} messageId - target chat message ID
 * @param {string} prompt - initial prompt text which will be continued
 */
function createSwipe(messageId, prompt) {
    // The prompt is raw model output, so it still needs the trimming and macro
    // replacement that regular messages receive.
    let cleanedPrompt = cleanUpMessage({
        getMessage: prompt,
        isImpersonate: false,
        isContinue: false,
        displayIncompleteSentences: true,
    });

    const msg = chat[messageId];
    const reasoningPrefix = substituteParamsExtended(power_user.reasoning.prefix);
    const reasoningSuffix = substituteParamsExtended(power_user.reasoning.suffix);
    const autoParse = power_user.reasoning.auto_parse;
    const hasParsedReasoning = msg.extra?.reasoning?.length > 0;
    let rerollReasoning = false;

    if (autoParse && hasParsedReasoning) {
        console.debug('saw autoparse on with reasoning in message');
        const hasPrefix = cleanedPrompt.includes(reasoningPrefix);
        const hasSuffix = cleanedPrompt.includes(reasoningSuffix);
        // NOTE(review): an earlier comment suggested a suffix-only prompt
        // (prefilled think start) should also be routed to the response, but
        // only the prefix+suffix case is handled here — confirm intent.
        if (hasPrefix && hasSuffix) {
            // Reasoning is complete: keep only what follows it.
            console.debug('...incl. end tag...rerolling response');
            const endOfThink = cleanedPrompt.indexOf(reasoningSuffix) + reasoningSuffix.length;
            cleanedPrompt = cleanedPrompt.substring(endOfThink);
        } else if (hasPrefix) {
            // Reasoning is still open: the continuation belongs in the reasoning block.
            console.debug('..with start tag but no end tag... reroll reasoning');
            rerollReasoning = true;
            console.debug('..no end tag...rerolling reasoning, so removing prefix');
            cleanedPrompt = cleanedPrompt.replace(reasoningPrefix, '');
        }
    }
    console.debug('cleanedPrompt: ', cleanedPrompt);

    const newSwipeInfo = {
        send_date: msg.send_date,
        gen_started: msg.gen_started,
        gen_finished: msg.gen_finished,
        extra: { ...structuredClone(msg.extra), from_logprobs: new Date().getTime() },
    };
    msg.swipes = msg.swipes || [];
    msg.swipe_info = msg.swipe_info || [];

    if (rerollReasoning) {
        // An empty message text makes the reasoning handler parse the reasoning first.
        newSwipeInfo.extra.reasoning = cleanedPrompt;
    }
    msg.swipes.push(rerollReasoning ? '' : cleanedPrompt);
    msg.swipe_info.push(newSwipeInfo);
    msg.swipe_id = Math.max(0, msg.swipes.length - 2);
}
/**
 * Makes whitespace in token text visible. Spaces and the `▁`/`Ġ` space markers
 * become `·`; newlines and the `Ċ` newline marker become `↵`. Other characters,
 * including tabs, are left unchanged.
 * @param {string} input - token text
 * @returns {string} the text with visible whitespace glyphs
 */
function toVisibleWhitespace(input) {
    const glyphFor = (ch) => (ch === '\n' || ch === 'Ċ' ? '↵' : '·');
    return input.replace(/[ ▁Ġ\nĊ]/g, glyphFor);
}
/**
 * Surrounds a token span with invisible helper nodes so the output can still
 * wrap at whitespace after it has been replaced by visible glyphs:
 * - a zero-width space before the span when the token starts with whitespace
 *   (plus one more when it starts with a `▁`/`Ġ` marker), and after it when the
 *   token ends with whitespace;
 * - `<br>` elements for a leading and/or trailing `\n`.
 * Consecutive or interior line breaks are not handled; per the original author,
 * tokenizers generally don't produce those.
 * @param {string} text - token text being evaluated for whitespace
 * @param {Node|JQuery} span - target span node to be wrapped
 * @returns {NodeArray} - array of nodes to be appended to the parent element
 */
function withVirtualWhitespace(text, span) {
    const startsWithSpace = text.match(/^\s/);
    const endsWithSpace = text.match(/\s$/);
    const startsWithMarker = text.match(/^[▁Ġ]/);
    // leading line break, at least one character, then a trailing line break
    const wrappedInBreaks = text.match(/^\n(?:.|\n)+\n$/);
    const startsWithBreak = text.match(/^\n/);
    const endsWithBreak = text.match(/\n$/);
    const breakAfter = wrappedInBreaks || (!startsWithBreak && endsWithBreak);
    const zeroWidthSpace = () => document.createTextNode('\u200b');

    /** @type {NodeArray} */
    const nodes = [];
    if (startsWithBreak) {
        nodes.push($('<br>'));
    }
    if (startsWithMarker) {
        nodes.push(zeroWidthSpace());
    }
    if (startsWithSpace) {
        nodes.push(zeroWidthSpace());
    }
    nodes.push(span);
    if (endsWithSpace) {
        nodes.push($(zeroWidthSpace()));
    }
    if (breakAfter) {
        nodes.push($('<br>'));
    }
    return nodes;
}
/**
 * Receives the top logprobs for each token in a message and associates it with the active message.
 *
 * Ensure the active message has been updated and rendered before calling this function
 * or the logprobs data will be saved to the wrong message.
 *
 * Does nothing when `logprobs` is null/undefined or the chat is empty. At most
 * MAX_MESSAGE_LOGPROBS entries are retained; the least recently saved are evicted.
 *
 * Callers:
 * - Generate:onSuccess via saveLogprobsForActiveMessage, for non-streaming text completion
 * - StreamingProcessor:onFinishStreaming, for streaming text completion
 * - sendOpenAIRequest, for non-streaming chat completion
 *
 * @param {TokenLogprobs[]} logprobs - array of logprobs data for each token
 * @param {string | null} continueFrom - for 'continue' generations, the prompt
 */
export function saveLogprobsForActiveMessage(logprobs, continueFrom) {
    if (!logprobs) {
        // non-streaming APIs could return null data
        return;
    }
    const msgId = chat.length - 1;
    const message = chat[msgId];
    if (!message) {
        // No message to attach the data to (the chat is empty).
        return;
    }

    const api = getGeneratingApi();
    // NovelAI only returns token IDs in logprobs data; convert to text tokens in-place
    if (api === 'novel') {
        convertTokenIdLogprobsToText(logprobs);
    }

    /** @type {MessageLogprobData} */
    const data = {
        created: new Date().getTime(),
        api,
        messageId: msgId,
        swipeId: message.swipe_id,
        messageLogprobs: logprobs,
        continueFrom,
        hash: getMessageHash(message),
    };

    // Delete before setting so the Map's insertion order reflects save recency
    // (set() on an existing key keeps its original position).
    state.messageLogprobs.delete(data.hash);
    state.messageLogprobs.set(data.hash, data);

    // Evict the oldest entries. A Map iterates in insertion order, and deleting
    // the current entry while iterating is safe. Timestamps are not used for
    // this, because equal-millisecond ties could evict the entry just saved.
    for (const hash of state.messageLogprobs.keys()) {
        if (state.messageLogprobs.size <= MAX_MESSAGE_LOGPROBS) {
            break;
        }
        state.messageLogprobs.delete(hash);
    }
}
/**
 * Hashes a chat message by its sender name, its index in `chat`, and its text.
 * @param {Object} message - a message object from `chat`
 * @returns {number} the hash from getStringHash over the serialized fields
 */
function getMessageHash(message) {
    // The swipe ID is intentionally excluded because it is not stable: deleting
    // a swipe shifts the IDs of all the swipes after it.
    const messageIndex = chat.indexOf(message);
    const serialized = JSON.stringify({
        name: message.name,
        mid: messageIndex,
        text: message.mes,
    });
    return getStringHash(serialized);
}
/**
 * Looks up the saved logprob data for the last message in the chat.
 * @returns {MessageLogprobData|null} the data, or null when the chat is empty or
 * nothing was saved for the last message in its current form
 */
function getActiveMessageLogprobData() {
    const lastIndex = chat.length - 1;
    if (lastIndex < 0) {
        return null;
    }
    const activeHash = getMessageHash(chat[lastIndex]);
    return state.messageLogprobs.get(activeHash) ?? null;
}
/**
 * convertTokenIdLogprobsToText replaces numeric token IDs in logprobs data with
 * their text, in place. It is used for APIs that return token IDs instead of
 * text tokens, to wit: NovelAI.
 *
 * An ID for which the tokenizer returns no text falls back to the ID as a
 * string, so text-only code downstream (e.g. toVisibleWhitespace) never
 * receives `undefined`.
 *
 * @param {TokenLogprobs[]} input - logprobs data with numeric token IDs; mutated in place
 * @throws {Error} if the current generating API is not NovelAI
 */
function convertTokenIdLogprobsToText(input) {
    const api = getGeneratingApi();
    if (api !== 'novel') {
        // should have been checked by the caller
        throw new Error('convertTokenIdLogprobsToText should only be called for NovelAI');
    }

    const tokenizerId = getTokenizerBestMatch(api);
    /** @type {any[]} Unique token IDs across all chosen and candidate tokens */
    const tokenIds = Array.from(new Set(input.flatMap(logprobs =>
        logprobs.topLogprobs.map(([token]) => token).concat(logprobs.token),
    )));

    // Submit token IDs to the tokenizer to get token text, then build an ID->text map.
    // noinspection JSCheckFunctionSignatures
    const { chunks = [] } = decodeTextTokens(tokenizerId, tokenIds);
    const tokenIdText = new Map(tokenIds.map((id, i) => [id, chunks[i] ?? String(id)]));

    // Fix up the logprobs data with token text (in place).
    input.forEach(logprobs => {
        logprobs.token = tokenIdText.get(logprobs.token);
        logprobs.topLogprobs = logprobs.topLogprobs.map(([token, logprob]) =>
            [tokenIdText.get(token), logprob],
        );
    });
}
/**
 * Wires up the Token Probabilities panel. It hides the reroll button, binds the
 * panel's open/close controls, and re-renders the panel (debounced) whenever the
 * chat or its messages change.
 */
export function initLogprobs() {
    REROLL_BUTTON.hide();
    const rerender = debounce(renderAlternativeTokensView);

    ['#logprobsViewerClose', '#option_toggle_logprobs']
        .forEach((selector) => $(selector).on('click', onToggleLogprobsPanel));

    const rerenderTriggers = [
        event_types.CHAT_CHANGED,
        event_types.CHARACTER_MESSAGE_RENDERED,
        event_types.IMPERSONATE_READY,
        event_types.MESSAGE_DELETED,
        event_types.MESSAGE_EDITED,
        event_types.MESSAGE_SWIPED,
    ];
    for (const eventType of rerenderTriggers) {
        eventSource.on(eventType, rerender);
    }
}