# %% [markdown]
# # Function Calling with FunctionGemma
#
# This notebook demonstrates how to use [Google's FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) for function calling with MLX-LM. We'll cover different use cases including datetime, weather, calendar, email, database queries, and more.
#
# ## Requirements
#
# ```bash
# pip install -U mlx-lm
# ```

# %% [markdown]
# ## Setup

# %%
import json
import re
from datetime import datetime

from mlx_lm import load, generate

# %%
# Model checkpoint to load; kept in one constant so it is easy to swap.
MODEL_ID = "mlx-community/functiongemma-270m-it-bf16"

# Load the model
model, tokenizer = load(MODEL_ID)
# The status glyph was mis-encoded ("āœ…"); it is the check-mark emoji.
print("✅ Model loaded")

# %% [markdown]
# ## Function Definitions
#
# Below are mock implementations of various functions that demonstrate different use cases. In production, these would connect to real APIs and services.
# %%
# Use Case 0: System DateTime
def get_current_datetime() -> dict:
    """Get the current system date and time.

    Returns:
        dict with the keys ``datetime`` (ISO 8601 string including the UTC
        offset), ``date`` (YYYY-MM-DD), ``time`` (HH:MM:SS), ``day_of_week``
        (full weekday name) and ``timezone`` (local timezone name).
    """
    # Make the timestamp timezone-aware up front so the ISO string carries the
    # same offset that the "timezone" field reports; a naive ISO string next
    # to a timezone name is ambiguous.
    now = datetime.now().astimezone()
    return {
        "datetime": now.isoformat(),
        "date": now.strftime("%Y-%m-%d"),
        "time": now.strftime("%H:%M:%S"),
        "day_of_week": now.strftime("%A"),
        "timezone": now.tzname(),
    }


datetime_schema = {
    "type": "function",
    "function": {
        "name": "get_current_datetime",
        "description": "Gets the current system date and time",
        "parameters": {
            "type": "OBJECT",
            "properties": {},
            "required": [],
        },
    },
}

# %%
# Use Case 1: Weather
def get_current_temperature(location: str) -> dict:
    """Get the current (mock) temperature for a location.

    Args:
        location: City name. Matching ignores case and surrounding or
            repeated whitespace, because the argument is model-generated and
            may come back as e.g. "london" or " New  York ".

    Returns:
        dict with ``location`` (echoed exactly as given), ``temperature`` and
        ``unit``. Cities missing from the mock table get the default of 20.
    """
    temps = {
        "London": 15, "Paris": 18, "Tokyo": 22,
        "New York": 12, "Sydney": 25
    }
    by_normalized_name = {city.casefold(): value for city, value in temps.items()}
    lookup_key = " ".join(str(location).split()).casefold()
    return {
        "location": location,
        "temperature": by_normalized_name.get(lookup_key, 20),
        "unit": "celsius",
    }


weather_schema = {
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "description": "Gets the current temperature for a given location.",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "location": {
                    "type": "STRING",
                    "description": "The city name, e.g. San Francisco",
                },
            },
            "required": ["location"],
        },
    },
}
# %%
# Use Case 2: Calendar
def create_calendar_event(title: str, date: str, time: str) -> dict:
    """Create a (mock) calendar event.

    Args:
        title: Event title.
        date: Event date, expected as YYYY-MM-DD (not validated).
        time: Event time, expected as HH:MM (not validated).

    Returns:
        dict with ``success`` and an ``event`` dict holding ``title``,
        ``date``, ``time`` and a deterministic ``event_id``.
    """
    # Function-scope import: hashlib is not in the notebook's import cell.
    import hashlib

    # Built-in hash() is salted per process for strings, so the id changed
    # (and was often negative) after every kernel restart. A SHA-256 digest
    # of title and date is stable across runs.
    digest = hashlib.sha256(f"{title}|{date}".encode("utf-8")).hexdigest()[:12]
    return {
        "success": True,
        "event": {
            "title": title,
            "date": date,
            "time": time,
            "event_id": f"evt_{digest}",
        },
    }


calendar_schema = {
    "type": "function",
    "function": {
        "name": "create_calendar_event",
        "description": "Creates a new event in the calendar",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "title": {"type": "STRING", "description": "Event title"},
                "date": {"type": "STRING", "description": "Event date (YYYY-MM-DD)"},
                "time": {"type": "STRING", "description": "Event time (HH:MM)"},
            },
            "required": ["title", "date", "time"],
        },
    },
}

# %%
# Use Case 3: Email
def send_email(recipient: str, subject: str, body: str) -> dict:
    """Send a (mock) email; nothing is actually delivered.

    Args:
        recipient: Destination email address.
        subject: Subject line.
        body: Message text (accepted but not used by the mock).

    Returns:
        dict with ``success``, a deterministic ``message_id`` and ``sent_to``.
    """
    import hashlib  # function-scope: not provided by the notebook's import cell

    # Stable digest instead of the per-process salted built-in hash().
    digest = hashlib.sha256(f"{recipient}|{subject}".encode("utf-8")).hexdigest()[:12]
    return {
        "success": True,
        "message_id": f"msg_{digest}",
        "sent_to": recipient,
    }


email_schema = {
    "type": "function",
    "function": {
        "name": "send_email",
        "description": "Sends an email to a recipient",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "recipient": {"type": "STRING", "description": "Email address"},
                "subject": {"type": "STRING", "description": "Email subject"},
                "body": {"type": "STRING", "description": "Email body"},
            },
            "required": ["recipient", "subject", "body"],
        },
    },
}

# %%
# Use Case 4: Database
def search_database(query: str, table: str) -> dict:
    """Search a (mock) database table.

    Args:
        query: Search query (accepted but not used by the mock).
        table: Table name, echoed back in the result.

    Returns:
        dict with two canned ``results``, their ``count`` and ``table``.
    """
    rows = [
        {"id": 1, "name": "Result 1"},
        {"id": 2, "name": "Result 2"},
    ]
    return {"results": rows, "count": len(rows), "table": table}


database_schema = {
    "type": "function",
    "function": {
        "name": "search_database",
        "description": "Searches a database table",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "query": {"type": "STRING", "description": "Search query"},
                "table": {"type": "STRING", "description": "Table name"},
            },
            "required": ["query", "table"],
        },
    },
}

# %%
# Use Case 5: File System
def list_files(directory: str) -> dict:
    """List (mock) files in a directory; the file system is not touched.

    Args:
        directory: Directory path, echoed back in the result.

    Returns:
        dict with three canned ``files``, their ``count`` and ``directory``.
    """
    names = ["file1.txt", "file2.py", "file3.md"]
    return {"files": names, "count": len(names), "directory": directory}


filesystem_schema = {
    "type": "function",
    "function": {
        "name": "list_files",
        "description": "Lists files in a directory",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "directory": {"type": "STRING", "description": "Directory path"},
            },
            "required": ["directory"],
        },
    },
}

# %%
# Use Case 6: Translation
def translate_text(text: str, target_language: str) -> dict:
    """Translate text (mock): the text is only tagged with the language.

    Args:
        text: Text to translate.
        target_language: Target language name.

    Returns:
        dict with ``original``, ``translated`` (the text prefixed with the
        upper-cased language in brackets) and ``language``.
    """
    return {
        "original": text,
        "translated": f"[{target_language.upper()}] {text}",
        "language": target_language,
    }


translation_schema = {
    "type": "function",
    "function": {
        "name": "translate_text",
        "description": "Translates text to another language",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "text": {"type": "STRING", "description": "Text to translate"},
                "target_language": {"type": "STRING", "description": "Target language (e.g., Spanish, French)"},
            },
            "required": ["text", "target_language"],
        },
    },
}

# %%
# Use Case 7: Calculator
def calculate(expression: str) -> dict:
    """Evaluate an arithmetic expression.

    SECURITY: the expression is generated by the model, i.e. untrusted input.
    The previous implementation passed it to eval(), which executes arbitrary
    Python. This version parses the text with ``ast`` and evaluates only
    numeric literals, parentheses, the binary operators + - * / // % ** and
    unary + / -. Anything else (names, calls, attribute access, strings) is
    rejected and reported as a failure.

    Args:
        expression: Math expression such as ``"2+2"`` or ``"10*5"``.

    Returns:
        On success ``{"expression", "result", "success": True}``; on failure
        ``{"expression", "error", "success": False}``. Never raises.
    """
    # Function-scope imports: not provided by the notebook's import cell.
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_exponent = 1000  # keeps inputs like 9**9**9 from hanging the kernel

    def evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        # type() check rather than isinstance() so that bool is excluded.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left, right = evaluate(node.left), evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError("Exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    try:
        # ast.parse (unlike eval) rejects leading whitespace, hence strip().
        tree = ast.parse(str(expression).strip(), mode="eval")
        result = evaluate(tree)
        return {"expression": expression, "result": result, "success": True}
    except Exception as e:
        # Deliberately broad: every failure is reported in the result dict.
        return {"expression": expression, "error": str(e), "success": False}


calculator_schema = {
    "type": "function",
    "function": {
        "name": "calculate",
        "description": "Evaluates a mathematical expression",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "expression": {"type": "STRING", "description": "Math expression (e.g., '2+2', '10*5')"},
            },
            "required": ["expression"],
        },
    },
}

# %%
# Use Case 8: Timer
def set_timer(duration_minutes: int, label: str) -> dict:
    """Set a (mock) timer; nothing is actually scheduled.

    Args:
        duration_minutes: Timer length in minutes.
        label: Human-readable timer label.

    Returns:
        dict with a deterministic ``timer_id``, ``duration_minutes``,
        ``label`` and ``status`` (always ``"active"``).
    """
    import hashlib  # function-scope: not provided by the notebook's import cell

    # Stable digest instead of the per-process salted built-in hash().
    digest = hashlib.sha256(str(label).encode("utf-8")).hexdigest()[:12]
    return {
        "timer_id": f"timer_{digest}",
        "duration_minutes": duration_minutes,
        "label": label,
        "status": "active",
    }


timer_schema = {
    "type": "function",
    "function": {
        "name": "set_timer",
        "description": "Sets a timer with specified duration",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "duration_minutes": {"type": "NUMBER", "description": "Duration in minutes"},
                "label": {"type": "STRING", "description": "Timer label"},
            },
            "required": ["duration_minutes", "label"],
        },
    },
}
# %%
# Use Case 9: News
def get_news(category: str, limit: int) -> dict:
    """Get (mock) news articles for a category.

    Args:
        category: News category, echoed back in the result.
        limit: Number of articles requested. The schema declares it as
            NUMBER, so it may arrive as a float (3.0) or a numeric string;
            it is coerced to an int and clamped at zero so that ``count``
            always matches the number of articles returned.

    Returns:
        dict with ``category``, ``articles`` (``"Article 1"`` ...) and
        ``count``.

    Raises:
        ValueError / TypeError: if ``limit`` cannot be converted to an int.
    """
    article_count = max(0, int(limit))
    return {
        "category": category,
        "articles": [f"Article {number}" for number in range(1, article_count + 1)],
        "count": article_count,
    }


news_schema = {
    "type": "function",
    "function": {
        "name": "get_news",
        "description": "Gets news articles by category",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "category": {"type": "STRING", "description": "News category (e.g., technology, sports)"},
                "limit": {"type": "NUMBER", "description": "Number of articles"},
            },
            "required": ["category", "limit"],
        },
    },
}

# %%
# Use Case 10: Reminder
def create_reminder(task: str, reminder_time: str) -> dict:
    """Create a (mock) reminder; nothing is actually scheduled.

    Args:
        task: Task description.
        reminder_time: Free-form time text such as ``"2pm tomorrow"``; it is
            stored as given, not parsed.

    Returns:
        dict with a deterministic ``reminder_id``, ``task``,
        ``reminder_time`` and ``status`` (always ``"scheduled"``).
    """
    import hashlib  # function-scope: not provided by the notebook's import cell

    # Built-in hash() is salted per process for strings, so the id changed
    # on every kernel restart; a SHA-256 digest is stable across runs.
    digest = hashlib.sha256(str(task).encode("utf-8")).hexdigest()[:12]
    return {
        "reminder_id": f"rem_{digest}",
        "task": task,
        "reminder_time": reminder_time,
        "status": "scheduled",
    }


reminder_schema = {
    "type": "function",
    "function": {
        "name": "create_reminder",
        "description": "Creates a reminder for a task",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "task": {"type": "STRING", "description": "Task description"},
                "reminder_time": {"type": "STRING", "description": "Time for reminder (e.g., '2pm tomorrow')"},
            },
            "required": ["task", "reminder_time"],
        },
    },
}
# %% [markdown]
# ## Function & Schema Registry

# %%
# Dispatch table: the function name emitted by the model -> Python callable.
# Each key is the callable's own __name__, i.e. the name used in its schema.
ALL_FUNCTIONS = {
    tool.__name__: tool
    for tool in (
        get_current_datetime,
        get_current_temperature,
        create_calendar_event,
        send_email,
        search_database,
        list_files,
        translate_text,
        calculate,
        set_timer,
        get_news,
        create_reminder,
    )
}

# Use-case label -> tool schema handed to the chat template.
ALL_SCHEMAS = dict(
    datetime=datetime_schema,
    weather=weather_schema,
    calendar=calendar_schema,
    email=email_schema,
    database=database_schema,
    filesystem=filesystem_schema,
    translation=translation_schema,
    calculator=calculator_schema,
    timer=timer_schema,
    news=news_schema,
    reminder=reminder_schema,
)

# %% [markdown]
# ## Response Parser
#
# FunctionGemma uses a specific format for function calls:
# ```
# <start_function_call>call:function_name{key:value,key:value}<end_function_call>
# ```
#
# String values may be wrapped in `<escape>` tags.
# %%
def _coerce_unquoted_value(raw: str):
    """Convert an un-escaped argument value to float, int, or a bare string."""
    try:
        return float(raw) if "." in raw else int(raw)
    except ValueError:
        # Not numeric: treat as a string and drop any surrounding quotes.
        return raw.strip("\"'")


def parse_function_call(response: str):
    """Parse FunctionGemma's response format.

    Expected shape::

        <start_function_call>call:name{key:value,key:<escape>text<escape>}<end_function_call>

    Args:
        response: Raw text generated by the model.

    Returns:
        ``(func_name, args)`` where ``args`` is a dict (empty for ``{}``), or
        ``(None, None)`` when the response contains no function call.

    Notes:
        * The control tokens are spelled out here; with empty-string markers
          every response was reported as "no function call".
        * Values wrapped in ``<escape>`` are string literals: they are taken
          verbatim (commas, colons and braces included) and are not converted
          to numbers.
        * Un-escaped values end at the next ``,`` or ``}`` and are converted
          to int/float when they look numeric. Nested lists or objects are
          not supported.
        * A missing end token or closing brace (truncated generation) is
          tolerated.
    """
    start_tag = "<start_function_call>"
    end_tag = "<end_function_call>"
    escape_tag = "<escape>"

    tag_pos = response.find(start_tag)
    if tag_pos == -1:
        return None, None

    body_start = tag_pos + len(start_tag)
    body_end = response.find(end_tag, body_start)
    if body_end == -1:
        call_str = response[body_start:].strip()
    else:
        call_str = response[body_start:body_end].strip()

    # Parse: call:function_name{key:value,key:value}
    header = re.match(r"call:(\w+)\{", call_str)
    if not header:
        return None, None
    func_name = header.group(1)

    args = {}
    pos = header.end()
    length = len(call_str)
    while pos < length and call_str[pos] != "}":
        # Locate the end of the key: the first ':', ',' or '}'.
        sep = pos
        while sep < length and call_str[sep] not in ":,}":
            sep += 1
        if sep >= length or call_str[sep] == "}":
            break  # trailing fragment without a value
        if call_str[sep] == ",":
            pos = sep + 1  # fragment without ':' carries no argument; skip it
            continue

        key = call_str[pos:sep].strip()
        pos = sep + 1
        while pos < length and call_str[pos].isspace():
            pos += 1

        if call_str.startswith(escape_tag, pos):
            value_start = pos + len(escape_tag)
            value_end = call_str.find(escape_tag, value_start)
            if value_end == -1:
                # Unterminated escape: keep the rest, minus a closing brace.
                value = call_str[value_start:].rstrip()
                if value.endswith("}"):
                    value = value[:-1]
                pos = length
            else:
                value = call_str[value_start:value_end]
                pos = value_end + len(escape_tag)
        else:
            value_end = pos
            while value_end < length and call_str[value_end] not in ",}":
                value_end += 1
            value = _coerce_unquoted_value(call_str[pos:value_end].strip())
            pos = value_end

        if key:
            args[key] = value

        # Advance past anything up to, and including, the next comma.
        while pos < length and call_str[pos] not in ",}":
            pos += 1
        if pos < length and call_str[pos] == ",":
            pos += 1

    return func_name, args
# %% [markdown]
# ## Test Runner

# %%
def run_function_call(query: str, schema: dict, verbose: bool = True) -> dict:
    """
    Run a function call with FunctionGemma.

    Builds a developer + user prompt, generates a response, parses the
    function call out of it and executes the matching Python function.

    Args:
        query: Natural language query from the user
        schema: Function schema to make available
        verbose: Whether to print detailed output

    Returns:
        dict with 'success', 'function', 'args', 'result' and 'error' keys.
        On failure 'function', 'args' and 'result' are None and 'error'
        explains why (no call detected, unknown function, or the function
        raised); on success 'error' is None.
    """
    current_dt = get_current_datetime()

    # IMPORTANT: Developer role activates function calling
    messages = [
        {
            "role": "developer",
            "content": f"""You are a model that can do function calling with the following functions.

Current system information:
- Date: {current_dt['date']} ({current_dt['day_of_week']})
- Time: {current_dt['time']}
- Timezone: {current_dt['timezone']}"""
        },
        {
            "role": "user",
            "content": query
        }
    ]

    # Apply chat template with tools
    prompt = tokenizer.apply_chat_template(
        messages,
        tools=[schema],
        add_generation_prompt=True,
        tokenize=False
    )

    # Generate response
    response = generate(model, tokenizer, prompt=prompt, max_tokens=1024, verbose=False)

    if verbose:
        print(f"Query: {query}")
        print(f"Response: {response}")

    def failure(reason: str) -> dict:
        """Report a failed run and build the failure result."""
        if verbose:
            print(f"\n❌ {reason}")
        return {"success": False, "function": None, "args": None, "result": None, "error": reason}

    # Parse and execute
    func_name, func_args = parse_function_call(response)
    if not func_name or func_args is None:
        return failure("No function call detected")

    if verbose:
        print(f"\n✅ Function: {func_name}()")
        print(f"   Arguments: {json.dumps(func_args, indent=2, default=str)}")

    # A parsed call to a function that is not registered used to be reported
    # as "No function call detected"; name the actual problem instead.
    if func_name not in ALL_FUNCTIONS:
        return failure(f"Unknown function: {func_name}")

    try:
        result = ALL_FUNCTIONS[func_name](**func_args)
    except Exception as exc:
        # Arguments are model-generated, so wrong or missing parameters are
        # expected now and then; report them instead of aborting the run.
        return failure(f"{func_name} failed: {type(exc).__name__}: {exc}")

    if verbose:
        print(f"   Result: {json.dumps(result, indent=2, default=str)}")
    return {"success": True, "function": func_name, "args": func_args, "result": result, "error": None}
# %% [markdown]
# ## Examples
#
# Let's test each use case:

# %%
# DateTime
run_function_call(query="What is the current date and time?", schema=datetime_schema)

# %%
# Weather
run_function_call(query="What's the temperature in London?", schema=weather_schema)

# %%
# Calendar
run_function_call(
    query="Create a meeting event for tomorrow at 2pm called Team Sync",
    schema=calendar_schema,
)

# %%
# Email
run_function_call(
    query="Send an email to john@example.com with subject 'Meeting' and body 'See you tomorrow'",
    schema=email_schema,
)

# %%
# Database
run_function_call(query="Search the users table for active accounts", schema=database_schema)

# %%
# File System
run_function_call(query="List files in the /home/user directory", schema=filesystem_schema)

# %%
# Translation
run_function_call(query="Translate 'Hello World' to Spanish", schema=translation_schema)

# %%
# Calculator
run_function_call(query="Calculate 25 * 4 + 10", schema=calculator_schema)

# %%
# Timer
run_function_call(query="Set a 5 minute timer for coffee", schema=timer_schema)

# %%
# News
run_function_call(query="Get 3 technology news articles", schema=news_schema)

# %%
# Reminder
run_function_call(query="Remind me to call mom at 6pm today", schema=reminder_schema)

# %% [markdown]
# ## Run All Tests
# %%
# One (label, schema, query) triple per use case.
test_cases = [
    ("datetime", datetime_schema, "What is the current date and time?"),
    ("weather", weather_schema, "What's the temperature in London?"),
    ("calendar", calendar_schema, "Create a meeting event for tomorrow at 2pm called Team Sync"),
    ("email", email_schema, "Send an email to john@example.com with subject 'Meeting' and body 'See you tomorrow'"),
    ("database", database_schema, "Search the users table for active accounts"),
    ("filesystem", filesystem_schema, "List files in the /home/user directory"),
    ("translation", translation_schema, "Translate 'Hello World' to Spanish"),
    ("calculator", calculator_schema, "Calculate 25 * 4 + 10"),
    ("timer", timer_schema, "Set a 5 minute timer for coffee"),
    ("news", news_schema, "Get 3 technology news articles"),
    ("reminder", reminder_schema, "Remind me to call mom at 6pm today"),
]

banner = "=" * 60

# Run every use case under a banner and count those that report success.
successful = 0
for name, schema, query in test_cases:
    print(f"\n{banner}")
    print(f"USE CASE: {name.upper()}")
    print(banner)
    result = run_function_call(query, schema)
    if result["success"]:
        successful += 1

pass_rate = successful / len(test_cases) * 100
print(f"\n\n{banner}")
print(f"SUMMARY: {successful}/{len(test_cases)} tests passed ({pass_rate:.1f}%)")
print(banner)

# %% [markdown]
# ## Key Takeaways
#
# 1. **Developer Role**: Use `role: "developer"` to activate function calling mode
# 2. **Schema Format**: Follow the OBJECT/STRING/NUMBER type format in schemas
# 3. **Response Format**: FunctionGemma uses `<start_function_call>call:name{args}<end_function_call>`
# 4. **Escaped Strings**: String values may be wrapped in `<escape>` tags
# 5. **Context**: Providing current datetime helps with time-aware queries