update inference structure (closes #3)

This commit is contained in:
Index 2025-11-06 02:02:45 -06:00
parent dda2552126
commit 2fbf412df4
3 changed files with 61 additions and 69 deletions

View file

@ -12,5 +12,6 @@ services:
- "HANDLE=${HANDLE:?}" - "HANDLE=${HANDLE:?}"
- "APP_PASSWORD=${APP_PASSWORD:?}" - "APP_PASSWORD=${APP_PASSWORD:?}"
- "GEMINI_API_KEY=${GEMINI_API_KEY:?}" - "GEMINI_API_KEY=${GEMINI_API_KEY:?}"
- "USE_JETSTREAM=${USE_JETSTREAM:-false}"
volumes: volumes:
- aero_db:/sqlite.db - aero_db:/sqlite.db

View file

@ -18,7 +18,12 @@ const logger = consola.withTag("Message Handler");
type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number]; type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number];
async function generateAIResponse(parsedConversation: string) { async function generateAIResponse(parsedContext: string, messages: {
role: string;
parts: {
text: string;
}[];
}[]) {
const config = { const config = {
model: env.GEMINI_MODEL, model: env.GEMINI_MODEL,
config: { config: {
@ -37,16 +42,14 @@ async function generateAIResponse(parsedConversation: string) {
], ],
}, },
{ {
role: "user" as const, role: "model" as const,
parts: [ parts: [
{ {
text: text: parsedContext,
`Below is the yaml for the current conversation. The last message is the one to respond to. The post is the current one you are meant to be analyzing.
${parsedConversation}`,
}, },
], ],
}, },
...messages,
]; ];
let inference = await c.ai.models.generateContent({ let inference = await c.ai.models.generateContent({
@ -99,19 +102,6 @@ ${parsedConversation}`,
return inference; return inference;
} }
async function sendResponse(
conversation: Conversation,
text: string,
): Promise<void> {
if (exceedsGraphemes(text)) {
multipartResponse(conversation, text);
} else {
conversation.sendMessage({
text,
});
}
}
export async function handler(message: ChatMessage): Promise<void> { export async function handler(message: ChatMessage): Promise<void> {
const conversation = await message.getConversation(); const conversation = await message.getConversation();
// ? Conversation should always be able to be found, but just in case: // ? Conversation should always be able to be found, but just in case:
@ -132,27 +122,26 @@ export async function handler(message: ChatMessage): Promise<void> {
return; return;
} }
const today = new Date(); if (message.senderDid != env.ADMIN_DID) {
today.setHours(0, 0, 0, 0); const todayStart = new Date();
const tomorrow = new Date(today); todayStart.setHours(0, 0, 0, 0);
tomorrow.setDate(tomorrow.getDate() + 1);
const dailyCount = await db const dailyCount = await db
.select({ count: count(messages.id) }) .select({ count: count(messages.id) })
.from(messages) .from(messages)
.where( .where(
and( and(
eq(messages.did, message.senderDid), eq(messages.did, message.senderDid),
gte(messages.created_at, today), gte(messages.created_at, todayStart),
lt(messages.created_at, tomorrow), ),
), );
);
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) { if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
conversation.sendMessage({ conversation.sendMessage({
text: c.QUOTA_EXCEEDED_MESSAGE, text: c.QUOTA_EXCEEDED_MESSAGE,
}); });
return; return;
}
} }
logger.success("Found conversation"); logger.success("Found conversation");
@ -160,12 +149,13 @@ export async function handler(message: ChatMessage): Promise<void> {
text: "...", text: "...",
}); });
const parsedConversation = await parseConversation(conversation); const parsedConversation = await parseConversation(conversation, message);
logger.info("Parsed conversation: ", parsedConversation);
try { try {
const inference = await generateAIResponse(parsedConversation); const inference = await generateAIResponse(
parsedConversation.context,
parsedConversation.messages,
);
if (!inference) { if (!inference) {
throw new Error("Failed to generate text. Returned undefined."); throw new Error("Failed to generate text. Returned undefined.");
} }
@ -176,7 +166,13 @@ export async function handler(message: ChatMessage): Promise<void> {
logger.success("Generated text:", inference.text); logger.success("Generated text:", inference.text);
saveMessage(conversation, env.DID, inference.text!); saveMessage(conversation, env.DID, inference.text!);
await sendResponse(conversation, responseText); if (exceedsGraphemes(responseText)) {
multipartResponse(conversation, responseText);
} else {
conversation.sendMessage({
text: responseText,
});
}
} }
} catch (error) { } catch (error) {
logger.error("Error in post handler:", error); logger.error("Error in post handler:", error);

View file

@ -14,9 +14,6 @@ import { parsePost, parsePostImages, traverseThread } from "./post";
/* /*
Utilities Utilities
*/ */
const resolveDid = (convo: Conversation, did: string) =>
convo.members.find((actor) => actor.did == did)!;
const getUserDid = (convo: Conversation) => const getUserDid = (convo: Conversation) =>
convo.members.find((actor) => actor.did != env.DID)!; convo.members.find((actor) => actor.did != env.DID)!;
@ -29,16 +26,9 @@ function generateRevision(bytes = 8) {
/* /*
Conversations Conversations
*/ */
async function initConvo(convo: Conversation) { async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
const user = getUserDid(convo); const user = getUserDid(convo);
const initialMessage = (await convo.getMessages()).messages[0] as
| ChatMessage
| undefined;
if (!initialMessage) {
throw new Error("Failed to get initial message of conversation");
}
const postUri = await parseMessagePostUri(initialMessage); const postUri = await parseMessagePostUri(initialMessage);
if (!postUri) { if (!postUri) {
convo.sendMessage({ convo.sendMessage({
@ -87,14 +77,14 @@ async function getConvo(convoId: string) {
return convo; return convo;
} }
export async function parseConversation(convo: Conversation) { export async function parseConversation(
convo: Conversation,
latestMessage: ChatMessage,
) {
let row = await getConvo(convo.id); let row = await getConvo(convo.id);
if (!row) { if (!row) {
row = await initConvo(convo); row = await initConvo(convo, latestMessage);
} else { } else {
const latestMessage = (await convo.getMessages())
.messages[0] as ChatMessage;
const postUri = await parseMessagePostUri(latestMessage); const postUri = await parseMessagePostUri(latestMessage);
if (postUri) { if (postUri) {
const [updatedRow] = await db const [updatedRow] = await db
@ -128,19 +118,23 @@ export async function parseConversation(convo: Conversation) {
let parseResult = null; let parseResult = null;
try { try {
parseResult = yaml.dump({ parseResult = {
post: await parsePost(post, true), context: yaml.dump({
post: await parsePost(post, true),
}),
messages: convoMessages.map((message) => { messages: convoMessages.map((message) => {
const profile = resolveDid(convo, message.did); const role = message.did == env.DID ? "model" : "user";
return { return {
user: profile.displayName role,
? `${profile.displayName} (${profile.handle})` parts: [
: `Handle: ${profile.handle}`, {
text: message.text, text: message.text,
},
],
}; };
}), }),
}); };
} catch (e) { } catch (e) {
convo.sendMessage({ convo.sendMessage({
text: text:
@ -169,7 +163,8 @@ async function getRelevantMessages(convo: typeof conversations.$inferSelect) {
.where( .where(
and( and(
eq(messages.conversationId, convo.id), eq(messages.conversationId, convo.id),
eq(messages.postUri, convo!.postUri), eq(messages.postUri, convo.postUri),
eq(messages.revision, convo.revision),
), ),
) )
.limit(15); .limit(15);
@ -192,7 +187,7 @@ export async function saveMessage(
.values({ .values({
conversationId: _convo.id, conversationId: _convo.id,
postUri: _convo.postUri, postUri: _convo.postUri,
revision: _convo.postUri, revision: _convo.revision,
did, did,
text, text,
}); });