update inference structure (closes #3)
This commit is contained in:
parent
dda2552126
commit
2fbf412df4
3 changed files with 61 additions and 69 deletions
|
|
@ -12,5 +12,6 @@ services:
|
||||||
- "HANDLE=${HANDLE:?}"
|
- "HANDLE=${HANDLE:?}"
|
||||||
- "APP_PASSWORD=${APP_PASSWORD:?}"
|
- "APP_PASSWORD=${APP_PASSWORD:?}"
|
||||||
- "GEMINI_API_KEY=${GEMINI_API_KEY:?}"
|
- "GEMINI_API_KEY=${GEMINI_API_KEY:?}"
|
||||||
|
- "USE_JETSTREAM=${USE_JETSTREAM:-false}"
|
||||||
volumes:
|
volumes:
|
||||||
- aero_db:/sqlite.db
|
- aero_db:/sqlite.db
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,12 @@ const logger = consola.withTag("Message Handler");
|
||||||
|
|
||||||
type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number];
|
type SupportedFunctionCall = typeof c.SUPPORTED_FUNCTION_CALLS[number];
|
||||||
|
|
||||||
async function generateAIResponse(parsedConversation: string) {
|
async function generateAIResponse(parsedContext: string, messages: {
|
||||||
|
role: string;
|
||||||
|
parts: {
|
||||||
|
text: string;
|
||||||
|
}[];
|
||||||
|
}[]) {
|
||||||
const config = {
|
const config = {
|
||||||
model: env.GEMINI_MODEL,
|
model: env.GEMINI_MODEL,
|
||||||
config: {
|
config: {
|
||||||
|
|
@ -37,16 +42,14 @@ async function generateAIResponse(parsedConversation: string) {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: "user" as const,
|
role: "model" as const,
|
||||||
parts: [
|
parts: [
|
||||||
{
|
{
|
||||||
text:
|
text: parsedContext,
|
||||||
`Below is the yaml for the current conversation. The last message is the one to respond to. The post is the current one you are meant to be analyzing.
|
|
||||||
|
|
||||||
${parsedConversation}`,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
...messages,
|
||||||
];
|
];
|
||||||
|
|
||||||
let inference = await c.ai.models.generateContent({
|
let inference = await c.ai.models.generateContent({
|
||||||
|
|
@ -99,19 +102,6 @@ ${parsedConversation}`,
|
||||||
return inference;
|
return inference;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function sendResponse(
|
|
||||||
conversation: Conversation,
|
|
||||||
text: string,
|
|
||||||
): Promise<void> {
|
|
||||||
if (exceedsGraphemes(text)) {
|
|
||||||
multipartResponse(conversation, text);
|
|
||||||
} else {
|
|
||||||
conversation.sendMessage({
|
|
||||||
text,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function handler(message: ChatMessage): Promise<void> {
|
export async function handler(message: ChatMessage): Promise<void> {
|
||||||
const conversation = await message.getConversation();
|
const conversation = await message.getConversation();
|
||||||
// ? Conversation should always be able to be found, but just in case:
|
// ? Conversation should always be able to be found, but just in case:
|
||||||
|
|
@ -132,27 +122,26 @@ export async function handler(message: ChatMessage): Promise<void> {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const today = new Date();
|
if (message.senderDid != env.ADMIN_DID) {
|
||||||
today.setHours(0, 0, 0, 0);
|
const todayStart = new Date();
|
||||||
const tomorrow = new Date(today);
|
todayStart.setHours(0, 0, 0, 0);
|
||||||
tomorrow.setDate(tomorrow.getDate() + 1);
|
|
||||||
|
|
||||||
const dailyCount = await db
|
const dailyCount = await db
|
||||||
.select({ count: count(messages.id) })
|
.select({ count: count(messages.id) })
|
||||||
.from(messages)
|
.from(messages)
|
||||||
.where(
|
.where(
|
||||||
and(
|
and(
|
||||||
eq(messages.did, message.senderDid),
|
eq(messages.did, message.senderDid),
|
||||||
gte(messages.created_at, today),
|
gte(messages.created_at, todayStart),
|
||||||
lt(messages.created_at, tomorrow),
|
),
|
||||||
),
|
);
|
||||||
);
|
|
||||||
|
|
||||||
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
|
if (dailyCount[0]!.count >= env.DAILY_QUERY_LIMIT) {
|
||||||
conversation.sendMessage({
|
conversation.sendMessage({
|
||||||
text: c.QUOTA_EXCEEDED_MESSAGE,
|
text: c.QUOTA_EXCEEDED_MESSAGE,
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.success("Found conversation");
|
logger.success("Found conversation");
|
||||||
|
|
@ -160,12 +149,13 @@ export async function handler(message: ChatMessage): Promise<void> {
|
||||||
text: "...",
|
text: "...",
|
||||||
});
|
});
|
||||||
|
|
||||||
const parsedConversation = await parseConversation(conversation);
|
const parsedConversation = await parseConversation(conversation, message);
|
||||||
|
|
||||||
logger.info("Parsed conversation: ", parsedConversation);
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const inference = await generateAIResponse(parsedConversation);
|
const inference = await generateAIResponse(
|
||||||
|
parsedConversation.context,
|
||||||
|
parsedConversation.messages,
|
||||||
|
);
|
||||||
if (!inference) {
|
if (!inference) {
|
||||||
throw new Error("Failed to generate text. Returned undefined.");
|
throw new Error("Failed to generate text. Returned undefined.");
|
||||||
}
|
}
|
||||||
|
|
@ -176,7 +166,13 @@ export async function handler(message: ChatMessage): Promise<void> {
|
||||||
logger.success("Generated text:", inference.text);
|
logger.success("Generated text:", inference.text);
|
||||||
saveMessage(conversation, env.DID, inference.text!);
|
saveMessage(conversation, env.DID, inference.text!);
|
||||||
|
|
||||||
await sendResponse(conversation, responseText);
|
if (exceedsGraphemes(responseText)) {
|
||||||
|
multipartResponse(conversation, responseText);
|
||||||
|
} else {
|
||||||
|
conversation.sendMessage({
|
||||||
|
text: responseText,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error("Error in post handler:", error);
|
logger.error("Error in post handler:", error);
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,6 @@ import { parsePost, parsePostImages, traverseThread } from "./post";
|
||||||
/*
|
/*
|
||||||
Utilities
|
Utilities
|
||||||
*/
|
*/
|
||||||
const resolveDid = (convo: Conversation, did: string) =>
|
|
||||||
convo.members.find((actor) => actor.did == did)!;
|
|
||||||
|
|
||||||
const getUserDid = (convo: Conversation) =>
|
const getUserDid = (convo: Conversation) =>
|
||||||
convo.members.find((actor) => actor.did != env.DID)!;
|
convo.members.find((actor) => actor.did != env.DID)!;
|
||||||
|
|
||||||
|
|
@ -29,16 +26,9 @@ function generateRevision(bytes = 8) {
|
||||||
/*
|
/*
|
||||||
Conversations
|
Conversations
|
||||||
*/
|
*/
|
||||||
async function initConvo(convo: Conversation) {
|
async function initConvo(convo: Conversation, initialMessage: ChatMessage) {
|
||||||
const user = getUserDid(convo);
|
const user = getUserDid(convo);
|
||||||
|
|
||||||
const initialMessage = (await convo.getMessages()).messages[0] as
|
|
||||||
| ChatMessage
|
|
||||||
| undefined;
|
|
||||||
if (!initialMessage) {
|
|
||||||
throw new Error("Failed to get initial message of conversation");
|
|
||||||
}
|
|
||||||
|
|
||||||
const postUri = await parseMessagePostUri(initialMessage);
|
const postUri = await parseMessagePostUri(initialMessage);
|
||||||
if (!postUri) {
|
if (!postUri) {
|
||||||
convo.sendMessage({
|
convo.sendMessage({
|
||||||
|
|
@ -87,14 +77,14 @@ async function getConvo(convoId: string) {
|
||||||
return convo;
|
return convo;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function parseConversation(convo: Conversation) {
|
export async function parseConversation(
|
||||||
|
convo: Conversation,
|
||||||
|
latestMessage: ChatMessage,
|
||||||
|
) {
|
||||||
let row = await getConvo(convo.id);
|
let row = await getConvo(convo.id);
|
||||||
if (!row) {
|
if (!row) {
|
||||||
row = await initConvo(convo);
|
row = await initConvo(convo, latestMessage);
|
||||||
} else {
|
} else {
|
||||||
const latestMessage = (await convo.getMessages())
|
|
||||||
.messages[0] as ChatMessage;
|
|
||||||
|
|
||||||
const postUri = await parseMessagePostUri(latestMessage);
|
const postUri = await parseMessagePostUri(latestMessage);
|
||||||
if (postUri) {
|
if (postUri) {
|
||||||
const [updatedRow] = await db
|
const [updatedRow] = await db
|
||||||
|
|
@ -128,19 +118,23 @@ export async function parseConversation(convo: Conversation) {
|
||||||
|
|
||||||
let parseResult = null;
|
let parseResult = null;
|
||||||
try {
|
try {
|
||||||
parseResult = yaml.dump({
|
parseResult = {
|
||||||
post: await parsePost(post, true),
|
context: yaml.dump({
|
||||||
|
post: await parsePost(post, true),
|
||||||
|
}),
|
||||||
messages: convoMessages.map((message) => {
|
messages: convoMessages.map((message) => {
|
||||||
const profile = resolveDid(convo, message.did);
|
const role = message.did == env.DID ? "model" : "user";
|
||||||
|
|
||||||
return {
|
return {
|
||||||
user: profile.displayName
|
role,
|
||||||
? `${profile.displayName} (${profile.handle})`
|
parts: [
|
||||||
: `Handle: ${profile.handle}`,
|
{
|
||||||
text: message.text,
|
text: message.text,
|
||||||
|
},
|
||||||
|
],
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
});
|
};
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
convo.sendMessage({
|
convo.sendMessage({
|
||||||
text:
|
text:
|
||||||
|
|
@ -169,7 +163,8 @@ async function getRelevantMessages(convo: typeof conversations.$inferSelect) {
|
||||||
.where(
|
.where(
|
||||||
and(
|
and(
|
||||||
eq(messages.conversationId, convo.id),
|
eq(messages.conversationId, convo.id),
|
||||||
eq(messages.postUri, convo!.postUri),
|
eq(messages.postUri, convo.postUri),
|
||||||
|
eq(messages.revision, convo.revision),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.limit(15);
|
.limit(15);
|
||||||
|
|
@ -192,7 +187,7 @@ export async function saveMessage(
|
||||||
.values({
|
.values({
|
||||||
conversationId: _convo.id,
|
conversationId: _convo.id,
|
||||||
postUri: _convo.postUri,
|
postUri: _convo.postUri,
|
||||||
revision: _convo.postUri,
|
revision: _convo.revision,
|
||||||
did,
|
did,
|
||||||
text,
|
text,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue