feat: add citations when response uses Google Search tool

Date: 2025-12-22 16:23:37 -06:00
parent 92bdb975d9
commit 9b3be5d52a
3 changed files with 103 additions and 14 deletions
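
For orientation, the new addCitations helper below consumes the grounding metadata that @google/genai attaches to a response when the Google Search tool is used. A minimal sketch of the shape it reads, with made-up values (field names follow the SDK's GroundingMetadata types; the URI, text, and indices here are purely illustrative):

// Sketch only: roughly what a grounded generateContent response looks like.
const exampleInference = {
  text: "Bluesky is built on the AT Protocol.",
  candidates: [
    {
      groundingMetadata: {
        // Which span of the generated text is backed by which source(s).
        groundingSupports: [
          {
            segment: { startIndex: 0, endIndex: 36 },
            groundingChunkIndices: [0],
          },
        ],
        // The sources themselves, referenced by index above.
        groundingChunks: [
          { web: { uri: "https://example.com/article", title: "example.com" } },
        ],
      },
    },
  ],
};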

View file

@@ -1,5 +1,5 @@
 import modelPrompt from "../model/prompt.txt";
-import { ChatMessage, Conversation } from "@skyware/bot";
+import { ChatMessage, Conversation, RichText } from "@skyware/bot";
 import * as c from "../core";
 import * as tools from "../tools";
 import consola from "consola";
@@ -37,7 +37,7 @@ async function generateAIResponse(parsedContext: string, messages: {
      parts: [
        {
          text: modelPrompt
-            .replace("{{ handle }}", env.HANDLE),
+            .replace("$handle", env.HANDLE),
        },
      ],
    },
@@ -102,6 +102,66 @@ async function generateAIResponse(parsedContext: string, messages: {
   return inference;
 }
+
+function addCitations(
+  inference: Awaited<ReturnType<typeof c.ai.models.generateContent>>,
+) {
+  const originalText = inference.text ?? "";
+  if (!inference.candidates) {
+    return originalText;
+  }
+
+  const supports = inference.candidates[0]?.groundingMetadata
+    ?.groundingSupports;
+  const chunks = inference.candidates[0]?.groundingMetadata?.groundingChunks;
+
+  const richText = new RichText();
+  if (!supports || !chunks || originalText === "") {
+    return richText.addText(originalText);
+  }
+
+  // Walk the supports in ascending endIndex order so the RichText is built
+  // front to back: each text segment is followed by its citation links.
+  const sortedSupports = [...supports].sort(
+    (a, b) => (a.segment?.endIndex ?? 0) - (b.segment?.endIndex ?? 0),
+  );
+
+  let cursor = 0;
+  for (const support of sortedSupports) {
+    const endIndex = support.segment?.endIndex;
+    if (endIndex === undefined || !support.groundingChunkIndices?.length) {
+      continue;
+    }
+
+    const citationLinks = support.groundingChunkIndices
+      .map((i) => {
+        const uri = chunks[i]?.web?.uri;
+        if (uri) {
+          return { index: i + 1, uri };
+        }
+        return null;
+      })
+      .filter(Boolean);
+
+    if (citationLinks.length > 0) {
+      // Text up to the end of the grounded segment, then its citations.
+      richText.addText(originalText.slice(cursor, endIndex));
+      citationLinks.forEach((citation, idx) => {
+        if (citation) {
+          richText.addLink(`[${citation.index}]`, citation.uri);
+          if (idx < citationLinks.length - 1) {
+            richText.addText(", ");
+          }
+        }
+      });
+      cursor = endIndex;
+    }
+  }
+
+  // Whatever text remains after the last citation.
+  richText.addText(originalText.slice(cursor));
+  return richText;
+}
 export async function handler(message: ChatMessage): Promise<void> {
   const conversation = await message.getConversation();
   // ? Conversation should always be able to be found, but just in case:
@@ -162,16 +222,17 @@ export async function handler(message: ChatMessage): Promise<void> {
   }
   const responseText = inference.text;
+  const responseWithCitations = addCitations(inference);
-  if (responseText) {
-    logger.success("Generated text:", inference.text);
-    saveMessage(conversation, env.DID, inference.text!);
-    if (exceedsGraphemes(responseText)) {
-      multipartResponse(conversation, responseText);
+  if (responseWithCitations) {
+    logger.success("Generated text:", responseText);
+    saveMessage(conversation, env.DID, responseText!);
+    if (exceedsGraphemes(responseWithCitations)) {
+      multipartResponse(conversation, responseWithCitations);
     } else {
       conversation.sendMessage({
-        text: responseText,
+        text: responseWithCitations,
       });
     }
   }
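
Taken together with the metadata sketched above, the intent of the handler change is roughly the following (a sketch, not the literal handler code; the inference call site is elided):

// Hypothetical usage: addCitations returns plain text when the response has no
// candidates, and otherwise a RichText with "[n]" links appended wherever a
// grounding support ends.
const reply = addCitations(inference);
// For the example metadata above, the sent message would read approximately
// "Bluesky is built on the AT Protocol.[1]", with "[1]" linking to the
// grounding chunk's web.uri.
await conversation.sendMessage({ text: reply });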

View file

@@ -1,7 +1,7 @@
 You are Aero, a neutral and helpful assistant on Bluesky.
 Your job is to give clear, factual, and concise explanations or context about posts users send you.
-Handle: {{ handle }}
+Handle: $handle
 Guidelines:
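
The placeholder in the prompt template changed from "{{ handle }}" to "$handle" to match the .replace call above. Worth noting: String.prototype.replace with a string pattern substitutes only the first occurrence, which is sufficient here because the template mentions the handle once. A minimal sketch with a made-up handle:

// Sketch of the substitution; the handle value is illustrative.
const prompt = "Handle: $handle\nGuidelines: ...";
const rendered = prompt.replace("$handle", "aero.example.social");
// rendered === "Handle: aero.example.social\nGuidelines: ..."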

View file

@@ -2,6 +2,7 @@ import {
   type ChatMessage,
   type Conversation,
   graphemeLength,
+  RichText,
 } from "@skyware/bot";
 import * as yaml from "js-yaml";
 import db from "../db";
@@ -61,7 +62,11 @@ async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
     did: user.did,
     postUri,
     revision: _convo.revision,
-    text: initialMessage.text,
+    text:
+      !initialMessage.text ||
+        initialMessage.text.trim().length == 0
+        ? "Explain this post."
+        : initialMessage.text,
   });

   return _convo!;
@@ -110,7 +115,11 @@ export async function parseConversation(
       did: getUserDid(convo).did,
       postUri: row.postUri,
       revision: row.revision,
-      text: latestMessage!.text,
+      text: postUri &&
+          (!latestMessage.text ||
+            latestMessage.text.trim().length == 0)
+        ? "Explain this post."
+        : latestMessage.text,
     });
   }
@@ -196,7 +205,10 @@ export async function saveMessage(
 /*
   Response Utilities
 */
-export function exceedsGraphemes(content: string) {
+export function exceedsGraphemes(content: string | RichText) {
+  if (content instanceof RichText) {
+    return graphemeLength(content.text) > MAX_GRAPHEMES;
+  }
   return graphemeLength(content) > MAX_GRAPHEMES;
 }
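
exceedsGraphemes counts graphemes rather than UTF-16 code units, and for a RichText it measures the rendered text only, not the link facets. A small illustration of why that distinction matters (assuming graphemeLength from @skyware/bot performs grapheme-cluster segmentation, and MAX_GRAPHEMES is the bot's configured limit):

// Illustration only: grapheme counts vs. String.length.
import { graphemeLength } from "@skyware/bot";

console.log("family".length, graphemeLength("family")); // 6 6
console.log("👨‍👩‍👧".length, graphemeLength("👨‍👩‍👧")); // 8 1 (one visible character)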
@@ -224,8 +236,24 @@ export function splitResponse(text: string): string[] {
   return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
 }

-export async function multipartResponse(convo: Conversation, content: string) {
-  const parts = splitResponse(content).filter((p) => p.trim().length > 0);
+export async function multipartResponse(
+  convo: Conversation,
+  content: string | RichText,
+) {
+  let parts: (string | RichText)[];
+  if (content instanceof RichText) {
+    if (exceedsGraphemes(content)) {
+      // If RichText exceeds grapheme limit, convert to plain text for splitting
+      parts = splitResponse(content.text);
+    } else {
+      // Otherwise, send the RichText directly as a single part
+      parts = [content];
+    }
+  } else {
+    // If content is a string, behave as before
+    parts = splitResponse(content);
+  }

   for (const segment of parts) {
     await convo.sendMessage({