aero/src/utils/conversation.ts

235 lines
6 KiB
TypeScript

import {
type ChatMessage,
type Conversation,
graphemeLength,
} from "@skyware/bot";
import * as yaml from "js-yaml";
import db from "../db";
import { conversations, messages } from "../db/schema";
import { and, eq } from "drizzle-orm";
import { env } from "../env";
import { bot, MAX_GRAPHEMES } from "../core";
import { parsePost, parsePostImages, traverseThread } from "./post";
/*
Utilities
*/
/**
 * Finds the conversation member that is not the bot account (`env.DID`).
 *
 * Despite its name this returns the member object itself; callers read
 * `.did` from it.
 *
 * @param convo - The chat conversation to inspect.
 * @returns The first member whose DID differs from the bot's DID.
 * @throws Error if every member is the bot, instead of letting an
 *   `undefined` leak out and fail later with an opaque TypeError.
 */
const getUserDid = (convo: Conversation) => {
	const user = convo.members.find((actor) => actor.did !== env.DID);
	if (!user) {
		throw new Error(`No non-bot member found in conversation ${convo.id}`);
	}
	return user;
};
/**
 * Produces a random lowercase hex token used to tag a conversation revision.
 *
 * @param bytes - Number of random bytes to draw; the result has twice this
 *   many hex characters.
 * @returns A hex string of length `2 * bytes`.
 */
function generateRevision(bytes = 8) {
	// getRandomValues fills the buffer in place and also returns it.
	const randomBytes = crypto.getRandomValues(new Uint8Array(bytes));
	let hex = "";
	for (const byte of randomBytes) {
		hex += byte.toString(16).padStart(2, "0");
	}
	return hex;
}
/*
Conversations
*/
/**
 * Creates the database record for a brand-new conversation and stores its
 * first message.
 *
 * The first message must embed a post; that post's URI becomes the
 * conversation's subject (`postUri`) and a fresh revision token is generated.
 * If no post is embedded, the user is asked to send one and an error is thrown.
 *
 * @param convo - The chat conversation being initialised.
 * @param initialMessage - The first message received in the conversation.
 * @returns The inserted `conversations` row.
 * @throws Error when the message embeds no post, or when the insert returns
 *   no row.
 */
async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
	const user = getUserDid(convo);
	const postUri = await parseMessagePostUri(initialMessage);
	if (!postUri) {
		// Awaited so the prompt is sent before we bail out, and so a failed send
		// propagates to the caller instead of becoming an unhandled rejection.
		await convo.sendMessage({
			text: "Please send a post for me to make sense of the noise for you.",
		});
		throw new Error("No post reference in initial message.");
	}

	// Both inserts run in one transaction so a failure cannot leave a
	// conversation row without its opening message.
	return await db.transaction(async (tx) => {
		const [created] = await tx
			.insert(conversations)
			.values({
				id: convo.id,
				did: user.did,
				postUri,
				revision: generateRevision(),
			})
			.returning();
		if (!created) {
			throw new Error("Error during database transaction");
		}
		await tx.insert(messages).values({
			conversationId: created.id,
			did: user.did,
			postUri,
			revision: created.revision,
			text: initialMessage.text,
		});
		return created;
	});
}
/**
 * Looks up the stored conversation row for a chat conversation ID.
 *
 * @param convoId - The chat conversation's ID (the `conversations.id` key).
 * @returns The matching row, or `undefined` when none exists.
 */
async function getConvo(convoId: string) {
	const rows = await db
		.select()
		.from(conversations)
		.where(eq(conversations.id, convoId))
		.limit(1);
	return rows[0];
}
/**
 * Loads (or creates) the stored state for a chat conversation, records the
 * latest message, and builds the model input for the conversation.
 *
 * If the latest message embeds a post, that post becomes the conversation's
 * subject and a new revision token is generated. Only messages stored under
 * the current post/revision are included in the returned history.
 *
 * @param convo - The chat conversation the message arrived in.
 * @param latestMessage - The message that triggered this call.
 * @returns `context`: the subject post (via `parsePost`) dumped as YAML under
 *   a `post` key; `messages`: stored messages for the current post/revision,
 *   with role `"model"` for the bot's own messages and `"user"` otherwise.
 * @throws Error when the conversation row cannot be created/updated, or when
 *   the subject post cannot be fetched or parsed (the user is told first).
 */
export async function parseConversation(
	convo: Conversation,
	latestMessage: ChatMessage,
) {
	let row = await getConvo(convo.id);
	if (!row) {
		// initConvo also stores the initial message.
		row = await initConvo(convo, latestMessage);
	} else {
		const postUri = await parseMessagePostUri(latestMessage);
		if (postUri) {
			const [updatedRow] = await db
				.update(conversations)
				.set({
					postUri,
					revision: generateRevision(),
				})
				// Scope the update to this conversation; without this filter the
				// statement would rewrite every conversation in the table.
				.where(eq(conversations.id, convo.id))
				.returning();
			if (!updatedRow) {
				throw new Error("Failed to update conversation in database");
			}
			row = updatedRow;
		}
		await db.insert(messages).values({
			conversationId: convo.id,
			did: getUserDid(convo).did,
			postUri: row.postUri,
			revision: row.revision,
			text: latestMessage.text,
		});
	}

	const convoMessages = await getRelevantMessages(row);
	try {
		// Fetching the post is inside the try so a fetch failure also gets the
		// user-facing apology rather than failing silently from their view.
		const post = await bot.getPost(row.postUri);
		return {
			context: yaml.dump({
				post: await parsePost(post, true),
			}),
			messages: convoMessages.map((message) => ({
				role: message.did === env.DID ? "model" : "user",
				parts: [{ text: message.text }],
			})),
		};
	} catch (e) {
		await convo.sendMessage({
			text: "Sorry, I ran into an issue analyzing that post. Please try again.",
		});
		// Keep the underlying reason so failures remain diagnosable.
		const reason = e instanceof Error ? e.message : String(e);
		throw new Error(`Failed to parse conversation: ${reason}`);
	}
}
/*
Messages
*/
/**
 * Extracts the URI of the post embedded in a chat message, if any.
 *
 * @param message - The chat message to inspect.
 * @returns The embedded post's `uri`, or `null` when the message has no embed.
 */
async function parseMessagePostUri(message: ChatMessage) {
	const { embed } = message;
	return embed ? embed.uri : null;
}
/**
 * Fetches up to 15 stored messages belonging to the conversation's current
 * subject: same conversation ID, same post URI, and same revision token.
 *
 * NOTE(review): the query has no `orderBy`, so which 15 rows are returned,
 * and in what order, is left to the database. For chat history this
 * presumably wants the most recent 15 in chronological order — confirm
 * against the `messages` schema and add an ordering column.
 *
 * @param convo - The stored conversation row whose current thread to load.
 * @returns The matching message rows (at most 15).
 */
async function getRelevantMessages(convo: typeof conversations.$inferSelect) {
	const belongsToCurrentThread = and(
		eq(messages.conversationId, convo.id),
		eq(messages.postUri, convo.postUri),
		eq(messages.revision, convo.revision),
	);
	const rows = await db
		.select()
		.from(messages)
		.where(belongsToCurrentThread)
		.limit(15);
	return rows;
}
/**
 * Persists a message under the conversation's current post and revision.
 *
 * @param convo - The chat conversation the message belongs to.
 * @param did - DID of the message author.
 * @param text - The message text.
 * @throws Error if no stored conversation exists for `convo.id`.
 */
export async function saveMessage(
	convo: Conversation,
	did: string,
	text: string,
) {
	const stored = await getConvo(convo.id);
	if (!stored) {
		throw new Error(`Failed to find conversation with ID: ${convo.id}`);
	}
	const { id: conversationId, postUri, revision } = stored;
	await db.insert(messages).values({
		conversationId,
		postUri,
		revision,
		did,
		text,
	});
}
/*
Response Utilities
*/
/**
 * Reports whether `content` is too long to send as a single message.
 *
 * @param content - Text to measure.
 * @returns `true` when its grapheme count is strictly greater than
 *   `MAX_GRAPHEMES`.
 */
export function exceedsGraphemes(content: string) {
	const length = graphemeLength(content);
	return MAX_GRAPHEMES < length;
}
/**
 * Hard-splits a single word into pieces of at most `limit` UTF-16 code units,
 * cutting only on code-point boundaries so surrogate pairs are never broken.
 * Each piece contains at least one code point, so the loop always progresses.
 */
function splitLongWord(word: string, limit: number): string[] {
	const pieces: string[] = [];
	let piece = "";
	// Iterating a string yields whole code points, not UTF-16 units.
	for (const codePoint of word) {
		if (piece && piece.length + codePoint.length > limit) {
			pieces.push(piece);
			piece = "";
		}
		piece += codePoint;
	}
	if (piece) pieces.push(piece);
	return pieces;
}

/**
 * Splits a long response into numbered chunks such as "(1/3) …".
 *
 * Words are packed greedily into chunks of at most `maxChunkLength`
 * characters. Lengths are measured in UTF-16 code units, which is never fewer
 * than the grapheme count, so the limit is conservative. A single word longer
 * than the limit (e.g. a URL) is hard-split so no chunk exceeds it, and empty
 * chunks are never emitted.
 *
 * @param text - The full response text.
 * @param maxChunkLength - Maximum length of a chunk before its "(i/n) "
 *   prefix. Defaults to `MAX_GRAPHEMES - 10`, leaving room for the prefix.
 * @returns `[text]` unchanged when it fits in one chunk; otherwise the
 *   prefixed chunks in order.
 */
export function splitResponse(
	text: string,
	maxChunkLength: number = MAX_GRAPHEMES - 10,
): string[] {
	const chunks: string[] = [];
	let current = "";
	const pushCurrent = () => {
		const trimmed = current.trim();
		if (trimmed) chunks.push(trimmed);
		current = "";
	};

	for (const word of text.split(" ")) {
		const candidate = current ? `${current} ${word}` : word;
		if (candidate.length <= maxChunkLength) {
			current = candidate;
			continue;
		}
		pushCurrent();
		if (word.length <= maxChunkLength) {
			current = word;
			continue;
		}
		// Oversized word: emit all but its last piece, and let the last piece
		// start the next chunk so following words can share it.
		const pieces = splitLongWord(word, maxChunkLength);
		current = pieces.pop() ?? "";
		chunks.push(...pieces);
	}
	pushCurrent();

	const total = chunks.length;
	if (total <= 1) return [text];
	return chunks.map((chunk, i) => `(${i + 1}/${total}) ${chunk}`);
}
/**
 * Sends `content` to the conversation, split into numbered parts when needed.
 *
 * Parts are sent one at a time, awaiting each send so they arrive in order.
 * Whitespace-only parts are skipped.
 *
 * @param convo - The chat conversation to reply in.
 * @param content - The full response text.
 */
export async function multipartResponse(convo: Conversation, content: string) {
	for (const segment of splitResponse(content)) {
		if (segment.trim().length === 0) continue;
		await convo.sendMessage({ text: segment });
	}
}