update inference structure (closes #3)

This commit is contained in:
Index 2025-11-06 02:02:45 -06:00
parent dda2552126
commit 2fbf412df4
3 changed files with 61 additions and 69 deletions

View file

@ -12,5 +12,6 @@ services:
- "HANDLE=${HANDLE:?}"
- "APP_PASSWORD=${APP_PASSWORD:?}"
- "GEMINI_API_KEY=${GEMINI_API_KEY:?}"
- "USE_JETSTREAM=${USE_JETSTREAM:-false}"
volumes:
- aero_db:/sqlite.db

View file

@ -18,7 +18,12 @@ const logger = consola.withTag("Message Handler");
type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number];
async function generateAIResponse(parsedConversation: string) {
async function generateAIResponse(parsedContext: string, messages: {
role: string;
parts: {
text: string;
}[];
}[]) {
const config = {
model: env.GEMINI_MODEL,
config: {
@ -37,16 +42,14 @@ async function generateAIResponse(parsedConversation: string) {
],
},
{
role: "user" as const,
role: "model" as const,
parts: [
{
text:
`Below is the yaml for the current conversation. The last message is the one to respond to. The post is the current one you are meant to be analyzing.
${parsedConversation}`,
text: parsedContext,
},
],
},
...messages,
];
let inference = await c.ai.models.generateContent({
@ -99,19 +102,6 @@ ${parsedConversation}`,
return inference;
}
async function sendResponse(
conversation: Conversation,
text: string,
): Promise<void> {
if (exceedsGraphemes(text)) {
multipartResponse(conversation, text);
} else {
conversation.sendMessage({
text,
});
}
}
export async function handler(message: ChatMessage): Promise<void> {
const conversation = await message.getConversation();
// ? Conversation should always be able to be found, but just in case:
@ -132,27 +122,26 @@ export async function handler(message: ChatMessage): Promise<void> {
return;
}
const today = new Date();
today.setHours(0, 0, 0, 0);
const tomorrow = new Date(today);
tomorrow.setDate(tomorrow.getDate() + 1);
if (message.senderDid != env.ADMIN_DID) {
const todayStart = new Date();
todayStart.setHours(0, 0, 0, 0);
const dailyCount = await db
.select({ count: count(messages.id) })
.from(messages)
.where(
and(
eq(messages.did, message.senderDid),
gte(messages.created_at, today),
lt(messages.created_at, tomorrow),
),
);
const dailyCount = await db
.select({ count: count(messages.id) })
.from(messages)
.where(
and(
eq(messages.did, message.senderDid),
gte(messages.created_at, todayStart),
),
);
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
conversation.sendMessage({
text: c.QUOTA_EXCEEDED_MESSAGE,
});
return;
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
conversation.sendMessage({
text: c.QUOTA_EXCEEDED_MESSAGE,
});
return;
}
}
logger.success("Found conversation");
@ -160,12 +149,13 @@ export async function handler(message: ChatMessage): Promise<void> {
text: "...",
});
const parsedConversation = await parseConversation(conversation);
logger.info("Parsed conversation: ", parsedConversation);
const parsedConversation = await parseConversation(conversation, message);
try {
const inference = await generateAIResponse(parsedConversation);
const inference = await generateAIResponse(
parsedConversation.context,
parsedConversation.messages,
);
if (!inference) {
throw new Error("Failed to generate text. Returned undefined.");
}
@ -176,7 +166,13 @@ export async function handler(message: ChatMessage): Promise<void> {
logger.success("Generated text:", inference.text);
saveMessage(conversation, env.DID, inference.text!);
await sendResponse(conversation, responseText);
if (exceedsGraphemes(responseText)) {
multipartResponse(conversation, responseText);
} else {
conversation.sendMessage({
text: responseText,
});
}
}
} catch (error) {
logger.error("Error in post handler:", error);

View file

@ -14,9 +14,6 @@ import { parsePost, parsePostImages, traverseThread } from "./post";
/*
Utilities
*/
const resolveDid = (convo: Conversation, did: string) =>
convo.members.find((actor) => actor.did == did)!;
const getUserDid = (convo: Conversation) =>
convo.members.find((actor) => actor.did != env.DID)!;
@ -29,16 +26,9 @@ function generateRevision(bytes = 8) {
/*
Conversations
*/
async function initConvo(convo: Conversation) {
async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
const user = getUserDid(convo);
const initialMessage = (await convo.getMessages()).messages[0] as
| ChatMessage
| undefined;
if (!initialMessage) {
throw new Error("Failed to get initial message of conversation");
}
const postUri = await parseMessagePostUri(initialMessage);
if (!postUri) {
convo.sendMessage({
@ -87,14 +77,14 @@ async function getConvo(convoId: string) {
return convo;
}
export async function parseConversation(convo: Conversation) {
export async function parseConversation(
convo: Conversation,
latestMessage: ChatMessage,
) {
let row = await getConvo(convo.id);
if (!row) {
row = await initConvo(convo);
row = await initConvo(convo, latestMessage);
} else {
const latestMessage = (await convo.getMessages())
.messages[0] as ChatMessage;
const postUri = await parseMessagePostUri(latestMessage);
if (postUri) {
const [updatedRow] = await db
@ -128,19 +118,23 @@ export async function parseConversation(convo: Conversation) {
let parseResult = null;
try {
parseResult = yaml.dump({
post: await parsePost(post, true),
parseResult = {
context: yaml.dump({
post: await parsePost(post, true),
}),
messages: convoMessages.map((message) => {
const profile = resolveDid(convo, message.did);
const role = message.did == env.DID ? "model" : "user";
return {
user: profile.displayName
? `${profile.displayName} (${profile.handle})`
: `Handle: ${profile.handle}`,
text: message.text,
role,
parts: [
{
text: message.text,
},
],
};
}),
});
};
} catch (e) {
convo.sendMessage({
text:
@ -169,7 +163,8 @@ async function getRelevantMessages(convo: typeof conversations.$inferSelect) {
.where(
and(
eq(messages.conversationId, convo.id),
eq(messages.postUri, convo!.postUri),
eq(messages.postUri, convo.postUri),
eq(messages.revision, convo.revision),
),
)
.limit(15);
@ -192,7 +187,7 @@ export async function saveMessage(
.values({
conversationId: _convo.id,
postUri: _convo.postUri,
revision: _convo.postUri,
revision: _convo.revision,
did,
text,
});