feat: add citations when response uses Google Search tool
This commit is contained in:
parent
92bdb975d9
commit
9b3be5d52a
3 changed files with 103 additions and 14 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import modelPrompt from "../model/prompt.txt";
|
||||
import { ChatMessage, Conversation } from "@skyware/bot";
|
||||
import { ChatMessage, Conversation, RichText } from "@skyware/bot";
|
||||
import * as c from "../core";
|
||||
import * as tools from "../tools";
|
||||
import consola from "consola";
|
||||
|
|
@ -37,7 +37,7 @@ async function generateAIResponse(parsedContext: string, messages: {
|
|||
parts: [
|
||||
{
|
||||
text: modelPrompt
|
||||
.replace("{{ handle }}", env.HANDLE),
|
||||
.replace("$handle", env.HANDLE),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@ -102,6 +102,66 @@ async function generateAIResponse(parsedContext: string, messages: {
|
|||
return inference;
|
||||
}
|
||||
|
||||
function addCitations(
|
||||
inference: Awaited<ReturnType<typeof c.ai.models.generateContent>>,
|
||||
) {
|
||||
let originalText = inference.text ?? "";
|
||||
if (!inference.candidates) {
|
||||
return originalText;
|
||||
}
|
||||
const supports = inference.candidates[0]?.groundingMetadata
|
||||
?.groundingSupports;
|
||||
const chunks = inference.candidates[0]?.groundingMetadata?.groundingChunks;
|
||||
|
||||
const richText = new RichText();
|
||||
|
||||
if (!supports || !chunks || originalText === "") {
|
||||
return richText.addText(originalText);
|
||||
}
|
||||
|
||||
const sortedSupports = [...supports].sort(
|
||||
(a, b) => (b.segment?.endIndex ?? 0) - (a.segment?.endIndex ?? 0),
|
||||
);
|
||||
|
||||
let currentText = originalText;
|
||||
|
||||
for (const support of sortedSupports) {
|
||||
const endIndex = support.segment?.endIndex;
|
||||
if (endIndex === undefined || !support.groundingChunkIndices?.length) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const citationLinks = support.groundingChunkIndices
|
||||
.map((i) => {
|
||||
const uri = chunks[i]?.web?.uri;
|
||||
if (uri) {
|
||||
return { index: i + 1, uri };
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
if (citationLinks.length > 0) {
|
||||
richText.addText(currentText.slice(endIndex));
|
||||
|
||||
citationLinks.forEach((citation, idx) => {
|
||||
if (citation) {
|
||||
richText.addLink(`[${citation.index}]`, citation.uri);
|
||||
if (idx < citationLinks.length - 1) {
|
||||
richText.addText(", ");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
currentText = currentText.slice(0, endIndex);
|
||||
}
|
||||
}
|
||||
|
||||
richText.addText(currentText);
|
||||
|
||||
return richText;
|
||||
}
|
||||
|
||||
export async function handler(message: ChatMessage): Promise<void> {
|
||||
const conversation = await message.getConversation();
|
||||
// ? Conversation should always be able to be found, but just in case:
|
||||
|
|
@ -162,16 +222,17 @@ export async function handler(message: ChatMessage): Promise<void> {
|
|||
}
|
||||
|
||||
const responseText = inference.text;
|
||||
const responseWithCitations = addCitations(inference);
|
||||
|
||||
if (responseText) {
|
||||
logger.success("Generated text:", inference.text);
|
||||
saveMessage(conversation, env.DID, inference.text!);
|
||||
if (responseWithCitations) {
|
||||
logger.success("Generated text:", responseText);
|
||||
saveMessage(conversation, env.DID, responseText!);
|
||||
|
||||
if (exceedsGraphemes(responseText)) {
|
||||
multipartResponse(conversation, responseText);
|
||||
if (exceedsGraphemes(responseWithCitations)) {
|
||||
multipartResponse(conversation, responseWithCitations);
|
||||
} else {
|
||||
conversation.sendMessage({
|
||||
text: responseText,
|
||||
text: responseWithCitations,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
You are Aero, a neutral and helpful assistant on Bluesky.
|
||||
Your job is to give clear, factual, and concise explanations or context about posts users send you.
|
||||
|
||||
Handle: {{ handle }}
|
||||
Handle: $handle
|
||||
|
||||
Guidelines:
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import {
|
|||
type ChatMessage,
|
||||
type Conversation,
|
||||
graphemeLength,
|
||||
RichText,
|
||||
} from "@skyware/bot";
|
||||
import * as yaml from "js-yaml";
|
||||
import db from "../db";
|
||||
|
|
@ -61,7 +62,11 @@ async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
|
|||
did: user.did,
|
||||
postUri,
|
||||
revision: _convo.revision,
|
||||
text: initialMessage.text,
|
||||
text:
|
||||
!initialMessage.text ||
|
||||
initialMessage.text.trim().length == 0
|
||||
? "Explain this post."
|
||||
: initialMessage.text,
|
||||
});
|
||||
|
||||
return _convo!;
|
||||
|
|
@ -110,7 +115,11 @@ export async function parseConversation(
|
|||
did: getUserDid(convo).did,
|
||||
postUri: row.postUri,
|
||||
revision: row.revision,
|
||||
text: latestMessage!.text,
|
||||
text: postUri &&
|
||||
(!latestMessage.text ||
|
||||
latestMessage.text.trim().length == 0)
|
||||
? "Explain this post."
|
||||
: latestMessage.text,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -196,7 +205,10 @@ export async function saveMessage(
|
|||
/*
|
||||
Response Utilities
|
||||
*/
|
||||
export function exceedsGraphemes(content: string) {
|
||||
export function exceedsGraphemes(content: string | RichText) {
|
||||
if (content instanceof RichText) {
|
||||
return graphemeLength(content.text) > MAX_GRAPHEMES;
|
||||
}
|
||||
return graphemeLength(content) > MAX_GRAPHEMES;
|
||||
}
|
||||
|
||||
|
|
@ -224,8 +236,24 @@ export function splitResponse(text: string): string[] {
|
|||
return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
|
||||
}
|
||||
|
||||
export async function multipartResponse(convo: Conversation, content: string) {
|
||||
const parts = splitResponse(content).filter((p) => p.trim().length > 0);
|
||||
export async function multipartResponse(
|
||||
convo: Conversation,
|
||||
content: string | RichText,
|
||||
) {
|
||||
let parts: (string | RichText)[];
|
||||
|
||||
if (content instanceof RichText) {
|
||||
if (exceedsGraphemes(content)) {
|
||||
// If RichText exceeds grapheme limit, convert to plain text for splitting
|
||||
parts = splitResponse(content.text);
|
||||
} else {
|
||||
// Otherwise, send the RichText directly as a single part
|
||||
parts = [content];
|
||||
}
|
||||
} else {
|
||||
// If content is a string, behave as before
|
||||
parts = splitResponse(content);
|
||||
}
|
||||
|
||||
for (const segment of parts) {
|
||||
await convo.sendMessage({
|
||||
|
|
|
|||
Loading…
Reference in a new issue