update inference structure (closes #3)
This commit is contained in:
parent
dda2552126
commit
2fbf412df4
3 changed files with 61 additions and 69 deletions
|
|
@ -12,5 +12,6 @@ services:
|
|||
- "HANDLE=${HANDLE:?}"
|
||||
- "APP_PASSWORD=${APP_PASSWORD:?}"
|
||||
- "GEMINI_API_KEY=${GEMINI_API_KEY:?}"
|
||||
- "USE_JETSTREAM=${USE_JETSTREAM:-false}"
|
||||
volumes:
|
||||
- aero_db:/sqlite.db
|
||||
|
|
|
|||
|
|
@ -18,7 +18,12 @@ const logger = consola.withTag("Message Handler");
|
|||
|
||||
type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number];
|
||||
|
||||
async function generateAIResponse(parsedConversation: string) {
|
||||
async function generateAIResponse(parsedContext: string, messages: {
|
||||
role: string;
|
||||
parts: {
|
||||
text: string;
|
||||
}[];
|
||||
}[]) {
|
||||
const config = {
|
||||
model: env.GEMINI_MODEL,
|
||||
config: {
|
||||
|
|
@ -37,16 +42,14 @@ async function generateAIResponse(parsedConversation: string) {
|
|||
],
|
||||
},
|
||||
{
|
||||
role: "user" as const,
|
||||
role: "model" as const,
|
||||
parts: [
|
||||
{
|
||||
text:
|
||||
`Below is the yaml for the current conversation. The last message is the one to respond to. The post is the current one you are meant to be analyzing.
|
||||
|
||||
${parsedConversation}`,
|
||||
text: parsedContext,
|
||||
},
|
||||
],
|
||||
},
|
||||
...messages,
|
||||
];
|
||||
|
||||
let inference = await c.ai.models.generateContent({
|
||||
|
|
@ -99,19 +102,6 @@ ${parsedConversation}`,
|
|||
return inference;
|
||||
}
|
||||
|
||||
async function sendResponse(
|
||||
conversation: Conversation,
|
||||
text: string,
|
||||
): Promise<void> {
|
||||
if (exceedsGraphemes(text)) {
|
||||
multipartResponse(conversation, text);
|
||||
} else {
|
||||
conversation.sendMessage({
|
||||
text,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export async function handler(message: ChatMessage): Promise<void> {
|
||||
const conversation = await message.getConversation();
|
||||
// ? Conversation should always be able to be found, but just in case:
|
||||
|
|
@ -132,27 +122,26 @@ export async function handler(message: ChatMessage): Promise<void> {
|
|||
return;
|
||||
}
|
||||
|
||||
const today = new Date();
|
||||
today.setHours(0, 0, 0, 0);
|
||||
const tomorrow = new Date(today);
|
||||
tomorrow.setDate(tomorrow.getDate() + 1);
|
||||
if (message.senderDid != env.ADMIN_DID) {
|
||||
const todayStart = new Date();
|
||||
todayStart.setHours(0, 0, 0, 0);
|
||||
|
||||
const dailyCount = await db
|
||||
.select({ count: count(messages.id) })
|
||||
.from(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.did, message.senderDid),
|
||||
gte(messages.created_at, today),
|
||||
lt(messages.created_at, tomorrow),
|
||||
),
|
||||
);
|
||||
const dailyCount = await db
|
||||
.select({ count: count(messages.id) })
|
||||
.from(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.did, message.senderDid),
|
||||
gte(messages.created_at, todayStart),
|
||||
),
|
||||
);
|
||||
|
||||
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
|
||||
conversation.sendMessage({
|
||||
text: c.QUOTA_EXCEEDED_MESSAGE,
|
||||
});
|
||||
return;
|
||||
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
|
||||
conversation.sendMessage({
|
||||
text: c.QUOTA_EXCEEDED_MESSAGE,
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
logger.success("Found conversation");
|
||||
|
|
@ -160,12 +149,13 @@ export async function handler(message: ChatMessage): Promise<void> {
|
|||
text: "...",
|
||||
});
|
||||
|
||||
const parsedConversation = await parseConversation(conversation);
|
||||
|
||||
logger.info("Parsed conversation: ", parsedConversation);
|
||||
const parsedConversation = await parseConversation(conversation, message);
|
||||
|
||||
try {
|
||||
const inference = await generateAIResponse(parsedConversation);
|
||||
const inference = await generateAIResponse(
|
||||
parsedConversation.context,
|
||||
parsedConversation.messages,
|
||||
);
|
||||
if (!inference) {
|
||||
throw new Error("Failed to generate text. Returned undefined.");
|
||||
}
|
||||
|
|
@ -176,7 +166,13 @@ export async function handler(message: ChatMessage): Promise<void> {
|
|||
logger.success("Generated text:", inference.text);
|
||||
saveMessage(conversation, env.DID, inference.text!);
|
||||
|
||||
await sendResponse(conversation, responseText);
|
||||
if (exceedsGraphemes(responseText)) {
|
||||
multipartResponse(conversation, responseText);
|
||||
} else {
|
||||
conversation.sendMessage({
|
||||
text: responseText,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error in post handler:", error);
|
||||
|
|
|
|||
|
|
@ -14,9 +14,6 @@ import { parsePost, parsePostImages, traverseThread } from "./post";
|
|||
/*
|
||||
Utilities
|
||||
*/
|
||||
const resolveDid = (convo: Conversation, did: string) =>
|
||||
convo.members.find((actor) => actor.did == did)!;
|
||||
|
||||
const getUserDid = (convo: Conversation) =>
|
||||
convo.members.find((actor) => actor.did != env.DID)!;
|
||||
|
||||
|
|
@ -29,16 +26,9 @@ function generateRevision(bytes = 8) {
|
|||
/*
|
||||
Conversations
|
||||
*/
|
||||
async function initConvo(convo: Conversation) {
|
||||
async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
|
||||
const user = getUserDid(convo);
|
||||
|
||||
const initialMessage = (await convo.getMessages()).messages[0] as
|
||||
| ChatMessage
|
||||
| undefined;
|
||||
if (!initialMessage) {
|
||||
throw new Error("Failed to get initial message of conversation");
|
||||
}
|
||||
|
||||
const postUri = await parseMessagePostUri(initialMessage);
|
||||
if (!postUri) {
|
||||
convo.sendMessage({
|
||||
|
|
@ -87,14 +77,14 @@ async function getConvo(convoId: string) {
|
|||
return convo;
|
||||
}
|
||||
|
||||
export async function parseConversation(convo: Conversation) {
|
||||
export async function parseConversation(
|
||||
convo: Conversation,
|
||||
latestMessage: ChatMessage,
|
||||
) {
|
||||
let row = await getConvo(convo.id);
|
||||
if (!row) {
|
||||
row = await initConvo(convo);
|
||||
row = await initConvo(convo, latestMessage);
|
||||
} else {
|
||||
const latestMessage = (await convo.getMessages())
|
||||
.messages[0] as ChatMessage;
|
||||
|
||||
const postUri = await parseMessagePostUri(latestMessage);
|
||||
if (postUri) {
|
||||
const [updatedRow] = await db
|
||||
|
|
@ -128,19 +118,23 @@ export async function parseConversation(convo: Conversation) {
|
|||
|
||||
let parseResult = null;
|
||||
try {
|
||||
parseResult = yaml.dump({
|
||||
post: await parsePost(post, true),
|
||||
parseResult = {
|
||||
context: yaml.dump({
|
||||
post: await parsePost(post, true),
|
||||
}),
|
||||
messages: convoMessages.map((message) => {
|
||||
const profile = resolveDid(convo, message.did);
|
||||
const role = message.did == env.DID ? "model" : "user";
|
||||
|
||||
return {
|
||||
user: profile.displayName
|
||||
? `${profile.displayName} (${profile.handle})`
|
||||
: `Handle: ${profile.handle}`,
|
||||
text: message.text,
|
||||
role,
|
||||
parts: [
|
||||
{
|
||||
text: message.text,
|
||||
},
|
||||
],
|
||||
};
|
||||
}),
|
||||
});
|
||||
};
|
||||
} catch (e) {
|
||||
convo.sendMessage({
|
||||
text:
|
||||
|
|
@ -169,7 +163,8 @@ async function getRelevantMessages(convo: typeof conversations.$inferSelect) {
|
|||
.where(
|
||||
and(
|
||||
eq(messages.conversationId, convo.id),
|
||||
eq(messages.postUri, convo!.postUri),
|
||||
eq(messages.postUri, convo.postUri),
|
||||
eq(messages.revision, convo.revision),
|
||||
),
|
||||
)
|
||||
.limit(15);
|
||||
|
|
@ -192,7 +187,7 @@ export async function saveMessage(
|
|||
.values({
|
||||
conversationId: _convo.id,
|
||||
postUri: _convo.postUri,
|
||||
revision: _convo.postUri,
|
||||
revision: _convo.revision,
|
||||
did,
|
||||
text,
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue