From 9b3be5d52ac1acff25a33a28a341cda90836657e Mon Sep 17 00:00:00 2001
From: Index
Date: Mon, 22 Dec 2025 16:23:37 -0600
Subject: [PATCH] feat: add citations when response uses Google Search tool

---
NOTE(review): recovered from a whitespace-mangled copy. Hunk line structure
was reconstructed from the diff markers and verified against every hunk's
@@ line counts; context-line indentation and the generic type arguments the
mangling stripped (`Promise<void>`, `Awaited<ReturnType<typeof
generateAIResponse>>`) are best-effort reconstructions -- confirm against
the original commit before applying.
NOTE(review): addCitations() walks groundingSupports from highest endIndex
to lowest while *appending* each trailing segment to the RichText builder,
so the assembled rich text appears to come out in reverse segment order --
verify against RichText's append semantics. Also,
`saveMessage(conversation, env.DID, responseText!)` can assert on undefined
when the model returns candidates but no text (responseWithCitations is
then a truthy empty RichText). Both are worth verifying upstream.

 src/handlers/messages.ts  | 77 +++++++++++++++++++++++++++++++++++----
 src/model/prompt.txt      |  2 +-
 src/utils/conversation.ts | 38 ++++++++++++++++---
 3 files changed, 103 insertions(+), 14 deletions(-)

diff --git a/src/handlers/messages.ts b/src/handlers/messages.ts
index 71546d7..3464e60 100644
--- a/src/handlers/messages.ts
+++ b/src/handlers/messages.ts
@@ -1,5 +1,5 @@
 import modelPrompt from "../model/prompt.txt";
-import { ChatMessage, Conversation } from "@skyware/bot";
+import { ChatMessage, Conversation, RichText } from "@skyware/bot";
 import * as c from "../core";
 import * as tools from "../tools";
 import consola from "consola";
@@ -37,7 +37,7 @@ async function generateAIResponse(parsedContext: string, messages: {
       parts: [
         {
           text: modelPrompt
-            .replace("{{ handle }}", env.HANDLE),
+            .replace("$handle", env.HANDLE),
         },
       ],
     },
@@ -102,6 +102,66 @@ async function generateAIResponse(parsedContext: string, messages: {
   return inference;
 }
 
+function addCitations(
+  inference: Awaited<ReturnType<typeof generateAIResponse>>,
+) {
+  let originalText = inference.text ?? "";
+  if (!inference.candidates) {
+    return originalText;
+  }
+  const supports = inference.candidates[0]?.groundingMetadata
+    ?.groundingSupports;
+  const chunks = inference.candidates[0]?.groundingMetadata?.groundingChunks;
+
+  const richText = new RichText();
+
+  if (!supports || !chunks || originalText === "") {
+    return richText.addText(originalText);
+  }
+
+  const sortedSupports = [...supports].sort(
+    (a, b) => (b.segment?.endIndex ?? 0) - (a.segment?.endIndex ?? 0),
+  );
+
+  let currentText = originalText;
+
+  for (const support of sortedSupports) {
+    const endIndex = support.segment?.endIndex;
+    if (endIndex === undefined || !support.groundingChunkIndices?.length) {
+      continue;
+    }
+
+    const citationLinks = support.groundingChunkIndices
+      .map((i) => {
+        const uri = chunks[i]?.web?.uri;
+        if (uri) {
+          return { index: i + 1, uri };
+        }
+        return null;
+      })
+      .filter(Boolean);
+
+    if (citationLinks.length > 0) {
+      richText.addText(currentText.slice(endIndex));
+
+      citationLinks.forEach((citation, idx) => {
+        if (citation) {
+          richText.addLink(`[${citation.index}]`, citation.uri);
+          if (idx < citationLinks.length - 1) {
+            richText.addText(", ");
+          }
+        }
+      });
+
+      currentText = currentText.slice(0, endIndex);
+    }
+  }
+
+  richText.addText(currentText);
+
+  return richText;
+}
+
 export async function handler(message: ChatMessage): Promise<void> {
   const conversation = await message.getConversation();
   // ? Conversation should always be able to be found, but just in case:
@@ -162,16 +222,17 @@ export async function handler(message: ChatMessage): Promise<void> {
   }
 
   const responseText = inference.text;
+  const responseWithCitations = addCitations(inference);
 
-  if (responseText) {
-    logger.success("Generated text:", inference.text);
-    saveMessage(conversation, env.DID, inference.text!);
+  if (responseWithCitations) {
+    logger.success("Generated text:", responseText);
+    saveMessage(conversation, env.DID, responseText!);
 
-    if (exceedsGraphemes(responseText)) {
-      multipartResponse(conversation, responseText);
+    if (exceedsGraphemes(responseWithCitations)) {
+      multipartResponse(conversation, responseWithCitations);
     } else {
       conversation.sendMessage({
-        text: responseText,
+        text: responseWithCitations,
       });
     }
   }
diff --git a/src/model/prompt.txt b/src/model/prompt.txt
index 8c75bcf..a9e2454 100644
--- a/src/model/prompt.txt
+++ b/src/model/prompt.txt
@@ -1,7 +1,7 @@
 You are Aero, a neutral and helpful assistant on Bluesky. 
 Your job is to give clear, factual, and concise explanations or context about posts users send you.
 
-Handle: {{ handle }}
+Handle: $handle
 
 Guidelines:
 
diff --git a/src/utils/conversation.ts b/src/utils/conversation.ts
index b817ff7..1b36d43 100644
--- a/src/utils/conversation.ts
+++ b/src/utils/conversation.ts
@@ -2,6 +2,7 @@ import {
   type ChatMessage,
   type Conversation,
   graphemeLength,
+  RichText,
 } from "@skyware/bot";
 import * as yaml from "js-yaml";
 import db from "../db";
@@ -61,7 +62,11 @@ async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
     did: user.did,
     postUri,
     revision: _convo.revision,
-    text: initialMessage.text,
+    text:
+      !initialMessage.text ||
+          initialMessage.text.trim().length == 0
+        ? "Explain this post."
+        : initialMessage.text,
   });
 
   return _convo!;
@@ -110,7 +115,11 @@ export async function parseConversation(
     did: getUserDid(convo).did,
     postUri: row.postUri,
     revision: row.revision,
-    text: latestMessage!.text,
+    text: postUri &&
+        (!latestMessage.text ||
+          latestMessage.text.trim().length == 0)
+      ? "Explain this post."
+      : latestMessage.text,
   });
 }
 
@@ -196,7 +205,10 @@ export async function saveMessage(
 
 /* Reponse Utilities */
 
-export function exceedsGraphemes(content: string) {
+export function exceedsGraphemes(content: string | RichText) {
+  if (content instanceof RichText) {
+    return graphemeLength(content.text) > MAX_GRAPHEMES;
+  }
   return graphemeLength(content) > MAX_GRAPHEMES;
 }
 
@@ -224,8 +236,24 @@ export function splitResponse(text: string): string[] {
   return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
 }
 
-export async function multipartResponse(convo: Conversation, content: string) {
-  const parts = splitResponse(content).filter((p) => p.trim().length > 0);
+export async function multipartResponse(
+  convo: Conversation,
+  content: string | RichText,
+) {
+  let parts: (string | RichText)[];
+
+  if (content instanceof RichText) {
+    if (exceedsGraphemes(content)) {
+      // If RichText exceeds grapheme limit, convert to plain text for splitting
+      parts = splitResponse(content.text);
+    } else {
+      // Otherwise, send the RichText directly as a single part
+      parts = [content];
+    }
+  } else {
+    // If content is a string, behave as before
+    parts = splitResponse(content);
+  }
 
   for (const segment of parts) {
     await convo.sendMessage({