diff --git a/src/handlers/posts.ts b/src/handlers/posts.ts index 56cf3f2..6c90137 100644 --- a/src/handlers/posts.ts +++ b/src/handlers/posts.ts @@ -83,6 +83,7 @@ ${parsedThread}`, const functionResponse = await tools.handler( call as typeof call & { name: SupportedFunctionCall }, + post.author.did, ); logger.log("Function response:", functionResponse); diff --git a/src/tools/add_to_memory.ts b/src/tools/add_to_memory.ts new file mode 100644 index 0000000..4fd3311 --- /dev/null +++ b/src/tools/add_to_memory.ts @@ -0,0 +1,65 @@ +import { Type } from "@google/genai"; +import { MemoryHandler } from "../utils/memory"; +import z from "zod"; + +export const definition = { + name: "add_to_memory", + description: "Adds or updates an entry in a user's memory block.", + parameters: { + type: Type.OBJECT, + properties: { + label: { + type: Type.STRING, + description: "The key or label for the memory entry.", + }, + value: { + type: Type.STRING, + description: "The value to be stored.", + }, + block: { + type: Type.STRING, + description: "The name of the memory block to add to. Defaults to 'memory'.", + }, + }, + required: ["label", "value"], + }, +}; + +export const validator = z.object({ + label: z.string(), + value: z.string(), + block: z.string().optional().default("memory"), +}); + +export async function handler( + args: z.infer, + did: string, +) { + const userMemory = new MemoryHandler( + did, + await MemoryHandler.getBlocks(did), + ); + + const blockHandler = userMemory.getBlockByName(args.block); + + if (!blockHandler) { + return { + success: false, + message: `Memory block with name '${args.block}' not found.`, + }; + } + + if (!blockHandler.block.mutable) { + return { + success: false, + message: `Memory block '${args.block}' is not mutable.`, + }; + } + + await blockHandler.createEntry(args.label, args.value); + + return { + success: true, + message: `Entry with label '${args.label}' has been added to the '${args.block}' memory block.`, + }; +} diff --git a/src/tools/create_blog_post.ts b/src/tools/create_blog_post.ts index 3ac477e..4409217 100644 --- a/src/tools/create_blog_post.ts +++ b/src/tools/create_blog_post.ts @@ -28,7 +28,10 @@ export const validator = z.object({ content: z.string(), }); -export async function handler(args: z.infer) { +export async function handler( + args: z.infer, + did: string, +) { //@ts-ignore: NSID is valid const entry = await bot.createRecord("com.whtwnd.blog.entry", { $type: "com.whtwnd.blog.entry", diff --git a/src/tools/create_post.ts b/src/tools/create_post.ts index 015a12c..d548894 100644 --- a/src/tools/create_post.ts +++ b/src/tools/create_post.ts @@ -25,7 +25,10 @@ export const validator = z.object({ text: z.string(), }); -export async function handler(args: z.infer) { +export async function handler( + args: z.infer, + did: string, +) { let uri: string | null = null; if (exceedsGraphemes(args.text)) { uri = await multipartResponse(args.text); diff --git a/src/tools/index.ts b/src/tools/index.ts index 9169147..bc8e06f 100644 --- a/src/tools/index.ts +++ b/src/tools/index.ts @@ -1,4 +1,5 @@ import type { FunctionCall, GenerateContentConfig } from "@google/genai"; +import * as add_to_memory from "./add_to_memory"; import * as create_blog_post from "./create_blog_post"; import * as create_post from "./create_post"; import * as mute_thread from "./mute_thread"; @@ -8,6 +9,7 @@ const validation_mappings = { "create_post": create_post.validator, "create_blog_post": create_blog_post.validator, "mute_thread": mute_thread.validator, + "add_to_memory": add_to_memory.validator, } as const; export const declarations = [ @@ -16,26 +18,38 @@ export const declarations = [ create_post.definition, create_blog_post.definition, mute_thread.definition, + add_to_memory.definition, ], }, ]; type ToolName = keyof typeof validation_mappings; -export async function handler(call: FunctionCall & { name: ToolName }) { +export async function handler( + call: FunctionCall & { name: ToolName }, + did: string, +) { const parsedArgs = validation_mappings[call.name].parse(call.args); switch (call.name) { case "create_post": return await create_post.handler( parsedArgs as z_infer, + did, ); case "create_blog_post": return await create_blog_post.handler( parsedArgs as z_infer, + did, ); case "mute_thread": return await mute_thread.handler( parsedArgs as z_infer, + did, + ); + case "add_to_memory": + return await add_to_memory.handler( + parsedArgs as z_infer, + did, ); } } diff --git a/src/tools/mute_thread.ts b/src/tools/mute_thread.ts index 35228b8..673b2bf 100644 --- a/src/tools/mute_thread.ts +++ b/src/tools/mute_thread.ts @@ -25,7 +25,10 @@ export const validator = z.object({ uri: z.string(), }); -export async function handler(args: z.infer) { +export async function handler( + args: z.infer, + did: string, +) { //@ts-ignore: NSID is valid const record = await bot.createRecord("dev.indexx.echo.threadmute", { $type: "dev.indexx.echo.threadmute", diff --git a/src/utils/memory.ts b/src/utils/memory.ts index 25b9214..9929c50 100644 --- a/src/utils/memory.ts +++ b/src/utils/memory.ts @@ -94,6 +94,10 @@ export class MemoryHandler { })), })); } + + public getBlockByName(name: string) { + return this.blocks.find((handler) => handler.block.name === name); + } } export class MemoryBlockHandler {