feat: add citations when response uses Google Search tool

This commit is contained in:
Index 2025-12-22 16:23:37 -06:00
parent 92bdb975d9
commit 9b3be5d52a
3 changed files with 103 additions and 14 deletions

View file

@ -1,5 +1,5 @@
import modelPrompt from "../model/prompt.txt";
import { ChatMessage, Conversation } from "@skyware/bot";
import { ChatMessage, Conversation, RichText } from "@skyware/bot";
import * as c from "../core";
import * as tools from "../tools";
import consola from "consola";
@ -37,7 +37,7 @@ async function generateAIResponse(parsedContext: string, messages: {
parts: [
{
text: modelPrompt
.replace("{{ handle }}", env.HANDLE),
.replace("$handle", env.HANDLE),
},
],
},
@ -102,6 +102,66 @@ async function generateAIResponse(parsedContext: string, messages: {
return inference;
}
/**
 * Builds a RichText response from a Gemini inference, appending numbered
 * citation links (from Google Search grounding metadata) after each
 * grounded text segment.
 *
 * Returns the plain text when the response has no candidates, an empty
 * RichText-wrapped text when grounding data is absent, and otherwise a
 * RichText whose segments appear in original document order with `[n]`
 * links inserted at each grounding support boundary.
 */
function addCitations(
  inference: Awaited<ReturnType<typeof c.ai.models.generateContent>>,
) {
  const originalText = inference.text ?? "";
  if (!inference.candidates) {
    return originalText;
  }
  const metadata = inference.candidates[0]?.groundingMetadata;
  const supports = metadata?.groundingSupports;
  const chunks = metadata?.groundingChunks;
  const richText = new RichText();
  if (!supports || !chunks || originalText === "") {
    return richText.addText(originalText);
  }
  // BUG FIX: the previous implementation iterated supports in *descending*
  // endIndex order while RichText.addText/addLink *append*, which emitted
  // the message segments in reverse document order. Process ascending with
  // a cursor so text and citations are appended in reading order.
  const sortedSupports = [...supports].sort(
    (a, b) => (a.segment?.endIndex ?? 0) - (b.segment?.endIndex ?? 0),
  );
  let cursor = 0;
  for (const support of sortedSupports) {
    const endIndex = support.segment?.endIndex;
    // Skip supports without a usable boundary, without chunk indices, or
    // overlapping a segment we have already emitted.
    if (
      endIndex === undefined ||
      endIndex < cursor ||
      !support.groundingChunkIndices?.length
    ) {
      continue;
    }
    const citationLinks = support.groundingChunkIndices
      .map((i) => {
        const uri = chunks[i]?.web?.uri;
        return uri ? { index: i + 1, uri } : null;
      })
      // Type predicate so TS narrows out the nulls (filter(Boolean) does not).
      .filter((link): link is { index: number; uri: string } => link !== null);
    if (citationLinks.length > 0) {
      // Emit the text up to this support boundary, then its citations.
      richText.addText(originalText.slice(cursor, endIndex));
      citationLinks.forEach((citation, idx) => {
        richText.addLink(`[${citation.index}]`, citation.uri);
        if (idx < citationLinks.length - 1) {
          richText.addText(", ");
        }
      });
      cursor = endIndex;
    }
  }
  // Remainder of the message after the last cited segment.
  richText.addText(originalText.slice(cursor));
  return richText;
}
export async function handler(message: ChatMessage): Promise<void> {
const conversation = await message.getConversation();
// ? Conversation should always be able to be found, but just in case:
@ -162,16 +222,17 @@ export async function handler(message: ChatMessage): Promise<void> {
}
const responseText = inference.text;
const responseWithCitations = addCitations(inference);
if (responseText) {
logger.success("Generated text:", inference.text);
saveMessage(conversation, env.DID, inference.text!);
if (responseWithCitations) {
logger.success("Generated text:", responseText);
saveMessage(conversation, env.DID, responseText!);
if (exceedsGraphemes(responseText)) {
multipartResponse(conversation, responseText);
if (exceedsGraphemes(responseWithCitations)) {
multipartResponse(conversation, responseWithCitations);
} else {
conversation.sendMessage({
text: responseText,
text: responseWithCitations,
});
}
}

View file

@ -1,7 +1,7 @@
You are Aero, a neutral and helpful assistant on Bluesky.
Your job is to give clear, factual, and concise explanations or context about posts users send you.
Handle: {{ handle }}
Handle: $handle
Guidelines:

View file

@ -2,6 +2,7 @@ import {
type ChatMessage,
type Conversation,
graphemeLength,
RichText,
} from "@skyware/bot";
import * as yaml from "js-yaml";
import db from "../db";
@ -61,7 +62,11 @@ async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
did: user.did,
postUri,
revision: _convo.revision,
text: initialMessage.text,
text:
!initialMessage.text ||
initialMessage.text.trim().length == 0
? "Explain this post."
: initialMessage.text,
});
return _convo!;
@ -110,7 +115,11 @@ export async function parseConversation(
did: getUserDid(convo).did,
postUri: row.postUri,
revision: row.revision,
text: latestMessage!.text,
text: postUri &&
(!latestMessage.text ||
latestMessage.text.trim().length == 0)
? "Explain this post."
: latestMessage.text,
});
}
@ -196,7 +205,10 @@ export async function saveMessage(
/*
Response Utilities
*/
export function exceedsGraphemes(content: string) {
/**
 * Whether the given content is too long to send as a single chat message.
 * Accepts either a plain string or a RichText (measured by its `.text`).
 */
export function exceedsGraphemes(content: string | RichText) {
  const text = content instanceof RichText ? content.text : content;
  return graphemeLength(text) > MAX_GRAPHEMES;
}
@ -224,8 +236,24 @@ export function splitResponse(text: string): string[] {
return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
}
export async function multipartResponse(convo: Conversation, content: string) {
const parts = splitResponse(content).filter((p) => p.trim().length > 0);
export async function multipartResponse(
convo: Conversation,
content: string | RichText,
) {
let parts: (string | RichText)[];
if (content instanceof RichText) {
if (exceedsGraphemes(content)) {
// If RichText exceeds grapheme limit, convert to plain text for splitting
parts = splitResponse(content.text);
} else {
// Otherwise, send the RichText directly as a single part
parts = [content];
}
} else {
// If content is a string, behave as before
parts = splitResponse(content);
}
for (const segment of parts) {
await convo.sendMessage({