diff --git a/docker-compose.yml b/docker-compose.yml index a8b2412..3dd3bd9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,5 +12,6 @@ services: - "HANDLE=${HANDLE:?}" - "APP_PASSWORD=${APP_PASSWORD:?}" - "GEMINI_API_KEY=${GEMINI_API_KEY:?}" + - "USE_JETSTREAM=${USE_JETSTREAM:-false}" volumes: - aero_db:/sqlite.db diff --git a/src/handlers/messages.ts b/src/handlers/messages.ts index be02328..47cd4c5 100644 --- a/src/handlers/messages.ts +++ b/src/handlers/messages.ts @@ -18,7 +18,12 @@ const logger = consola.withTag("Message Handler"); type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number]; -async function generateAIResponse(parsedConversation: string) { +async function generateAIResponse(parsedContext: string, messages: { + role: string; + parts: { + text: string; + }[]; +}[]) { const config = { model: env.GEMINI_MODEL, config: { @@ -37,16 +42,14 @@ async function generateAIResponse(parsedConversation: string) { ], }, { - role: "user" as const, + role: "model" as const, parts: [ { - text: - `Below is the yaml for the current conversation. The last message is the one to respond to. The post is the current one you are meant to be analyzing. - -${parsedConversation}`, + text: parsedContext, }, ], }, + ...messages, ]; let inference = await c.ai.models.generateContent({ @@ -99,19 +102,6 @@ ${parsedConversation}`, return inference; } -async function sendResponse( - conversation: Conversation, - text: string, -): Promise { - if (exceedsGraphemes(text)) { - multipartResponse(conversation, text); - } else { - conversation.sendMessage({ - text, - }); - } -} - export async function handler(message: ChatMessage): Promise { const conversation = await message.getConversation(); // ? Conversation should always be able to be found, but just in case: @@ -132,27 +122,26 @@ export async function handler(message: ChatMessage): Promise { return; } - const today = new Date(); - today.setHours(0, 0, 0, 0); - const tomorrow = new Date(today); - tomorrow.setDate(tomorrow.getDate() + 1); + if (message.senderDid != env.ADMIN_DID) { + const todayStart = new Date(); + todayStart.setHours(0, 0, 0, 0); - const dailyCount = await db - .select({ count: count(messages.id) }) - .from(messages) - .where( - and( - eq(messages.did, message.senderDid), - gte(messages.created_at, today), - lt(messages.created_at, tomorrow), - ), - ); + const dailyCount = await db + .select({ count: count(messages.id) }) + .from(messages) + .where( + and( + eq(messages.did, message.senderDid), + gte(messages.created_at, todayStart), + ), + ); - if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) { - conversation.sendMessage({ - text: c.QUOTA_EXCEEDED_MESSAGE, - }); - return; + if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) { + conversation.sendMessage({ + text: c.QUOTA_EXCEEDED_MESSAGE, + }); + return; + } } logger.success("Found conversation"); @@ -160,12 +149,13 @@ export async function handler(message: ChatMessage): Promise { text: "...", }); - const parsedConversation = await parseConversation(conversation); - - logger.info("Parsed conversation: ", parsedConversation); + const parsedConversation = await parseConversation(conversation, message); try { - const inference = await generateAIResponse(parsedConversation); + const inference = await generateAIResponse( + parsedConversation.context, + parsedConversation.messages, + ); if (!inference) { throw new Error("Failed to generate text. Returned undefined."); } @@ -176,7 +166,13 @@ export async function handler(message: ChatMessage): Promise { logger.success("Generated text:", inference.text); saveMessage(conversation, env.DID, inference.text!); - await sendResponse(conversation, responseText); + if (exceedsGraphemes(responseText)) { + multipartResponse(conversation, responseText); + } else { + conversation.sendMessage({ + text: responseText, + }); + } } } catch (error) { logger.error("Error in post handler:", error); diff --git a/src/utils/conversation.ts b/src/utils/conversation.ts index 02d46a2..6fd5d45 100644 --- a/src/utils/conversation.ts +++ b/src/utils/conversation.ts @@ -14,9 +14,6 @@ import { parsePost, parsePostImages, traverseThread } from "./post"; /* Utilities */ -const resolveDid = (convo: Conversation, did: string) => - convo.members.find((actor) => actor.did == did)!; - const getUserDid = (convo: Conversation) => convo.members.find((actor) => actor.did != env.DID)!; @@ -29,16 +26,9 @@ function generateRevision(bytes = 8) { /* Conversations */ -async function initConvo(convo: Conversation) { +async function initConvo(convo: Conversation, initialMessage: ChatMessage) { const user = getUserDid(convo); - const initialMessage = (await convo.getMessages()).messages[0] as - | ChatMessage - | undefined; - if (!initialMessage) { - throw new Error("Failed to get initial message of conversation"); - } - const postUri = await parseMessagePostUri(initialMessage); if (!postUri) { convo.sendMessage({ @@ -87,14 +77,14 @@ async function getConvo(convoId: string) { return convo; } -export async function parseConversation(convo: Conversation) { +export async function parseConversation( + convo: Conversation, + latestMessage: ChatMessage, +) { let row = await getConvo(convo.id); if (!row) { - row = await initConvo(convo); + row = await initConvo(convo, latestMessage); } else { - const latestMessage = (await convo.getMessages()) - .messages[0] as ChatMessage; - const postUri = await parseMessagePostUri(latestMessage); if (postUri) { const [updatedRow] = await db @@ -128,19 +118,23 @@ export async function parseConversation(convo: Conversation) { let parseResult = null; try { - parseResult = yaml.dump({ - post: await parsePost(post, true), + parseResult = { + context: yaml.dump({ + post: await parsePost(post, true), + }), messages: convoMessages.map((message) => { - const profile = resolveDid(convo, message.did); + const role = message.did == env.DID ? "model" : "user"; return { - user: profile.displayName - ? `${profile.displayName} (${profile.handle})` - : `Handle: ${profile.handle}`, - text: message.text, + role, + parts: [ + { + text: message.text, + }, + ], }; }), - }); + }; } catch (e) { convo.sendMessage({ text: @@ -169,7 +163,8 @@ async function getRelevantMessages(convo: typeof conversations.$inferSelect) { .where( and( eq(messages.conversationId, convo.id), - eq(messages.postUri, convo!.postUri), + eq(messages.postUri, convo.postUri), + eq(messages.revision, convo.revision), ), ) .limit(15); @@ -192,7 +187,7 @@ export async function saveMessage( .values({ conversationId: _convo.id, postUri: _convo.postUri, - revision: _convo.postUri, + revision: _convo.revision, did, text, });