Get Gemini to work
This commit is contained in:
@@ -20,7 +20,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 63,
|
"execution_count": null,
|
||||||
"id": "7b624d5b-69a2-441f-9147-fde105d3d551",
|
"id": "7b624d5b-69a2-441f-9147-fde105d3d551",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -38,7 +38,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 64,
|
"execution_count": null,
|
||||||
"id": "a07e7793-b8f5-44f4-aded-5562f633271a",
|
"id": "a07e7793-b8f5-44f4-aded-5562f633271a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -57,20 +57,10 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 65,
|
"execution_count": null,
|
||||||
"id": "efb88276-6d74-4d94-95a2-b8ca82a4716c",
|
"id": "efb88276-6d74-4d94-95a2-b8ca82a4716c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"OpenAI API Key exists and begins sk-proj-\n",
|
|
||||||
"Anthropic API Key exists and begins sk-ant-a\n",
|
|
||||||
"Google API Key exists and begins AIzaSyAS\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Load environment variables\n",
|
"# Load environment variables\n",
|
||||||
"load_dotenv()\n",
|
"load_dotenv()\n",
|
||||||
@@ -98,7 +88,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 66,
|
"execution_count": null,
|
||||||
"id": "484f0c3e-638d-4af7-bb9b-36faf6048f3c",
|
"id": "484f0c3e-638d-4af7-bb9b-36faf6048f3c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -113,7 +103,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 67,
|
"execution_count": null,
|
||||||
"id": "2e292401-e62f-4bfc-b060-07462ad20d3d",
|
"id": "2e292401-e62f-4bfc-b060-07462ad20d3d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -128,7 +118,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 68,
|
"execution_count": null,
|
||||||
"id": "84252e03-ccde-4ecf-975b-78227291ca5c",
|
"id": "84252e03-ccde-4ecf-975b-78227291ca5c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -142,7 +132,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 69,
|
"execution_count": null,
|
||||||
"id": "49396924-47c2-4f7d-baa2-9b0fece9da4a",
|
"id": "49396924-47c2-4f7d-baa2-9b0fece9da4a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -165,7 +155,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 70,
|
"execution_count": null,
|
||||||
"id": "c4d23747-d78a-4f36-9862-c00e1e8d9e44",
|
"id": "c4d23747-d78a-4f36-9862-c00e1e8d9e44",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -179,7 +169,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 71,
|
"execution_count": null,
|
||||||
"id": "67e150be-502e-4ba4-9586-3a2f3fae3830",
|
"id": "67e150be-502e-4ba4-9586-3a2f3fae3830",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -214,7 +204,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 72,
|
"execution_count": null,
|
||||||
"id": "bd9d0511-2f78-4270-81f8-73708388dfad",
|
"id": "bd9d0511-2f78-4270-81f8-73708388dfad",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -239,7 +229,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 73,
|
"execution_count": null,
|
||||||
"id": "47733d5b-bb0a-44dd-b56d-a54677c88f80",
|
"id": "47733d5b-bb0a-44dd-b56d-a54677c88f80",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -260,24 +250,24 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Gemini tool definition must be a FunctionDeclaration object without the top-level `type` in parameters.\n",
|
"# Gemini tool definition must be a FunctionDeclaration object without the top-level `type` in parameters.\n",
|
||||||
"tools_gemini = [google.generativeai.protos.FunctionDeclaration(\n",
|
"tools_gemini = [google.generativeai.protos.FunctionDeclaration(\n",
|
||||||
" name=scraping_function[\"name\"],\n",
|
" name=portable_scraping_function_definition[\"name\"],\n",
|
||||||
" description=scraping_function[\"description\"],\n",
|
" description=portable_scraping_function_definition[\"description\"],\n",
|
||||||
" parameters=google.generativeai.protos.Schema(\n",
|
" parameters=google.generativeai.protos.Schema(\n",
|
||||||
" type=google.generativeai.protos.Type.OBJECT,\n",
|
" type=google.generativeai.protos.Type.OBJECT,\n",
|
||||||
" properties={\n",
|
" properties={\n",
|
||||||
" \"text\": google.generativeai.protos.Schema(\n",
|
" \"text\": google.generativeai.protos.Schema(\n",
|
||||||
" type=google.generativeai.protos.Type.STRING,\n",
|
" type=google.generativeai.protos.Type.STRING,\n",
|
||||||
" description=scraping_function[\"parameters\"][\"properties\"][\"text\"][\"description\"]\n",
|
" description=portable_scraping_function_definition[\"parameters\"][\"properties\"][\"text\"][\"description\"]\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" required=scraping_function[\"parameters\"][\"required\"]\n",
|
" required=portable_scraping_function_definition[\"parameters\"][\"required\"]\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
")]\n"
|
")]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 74,
|
"execution_count": null,
|
||||||
"id": "aa3fa01b-97d0-443e-b0cc-55d277878cb7",
|
"id": "aa3fa01b-97d0-443e-b0cc-55d277878cb7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -296,16 +286,19 @@
|
|||||||
" arguments = json.loads(tool_call['function']['arguments'])\n",
|
" arguments = json.loads(tool_call['function']['arguments'])\n",
|
||||||
" except (json.JSONDecodeError, TypeError):\n",
|
" except (json.JSONDecodeError, TypeError):\n",
|
||||||
" arguments = {'text': tool_call['function'].get('arguments', user_message)}\n",
|
" arguments = {'text': tool_call['function'].get('arguments', user_message)}\n",
|
||||||
" elif hasattr(tool_call, 'function'): # GPT, Claude, Gemini\n",
|
" elif hasattr(tool_call, 'function'): # GPT, Claude\n",
|
||||||
" function_name = tool_call.function.name\n",
|
" function_name = tool_call.function.name\n",
|
||||||
" tool_call_id = getattr(tool_call, 'id', None)\n",
|
" tool_call_id = getattr(tool_call, 'id', None)\n",
|
||||||
" if isinstance(tool_call.function.arguments, dict): # For Gemini\n",
|
" if isinstance(tool_call.function.arguments, dict):\n",
|
||||||
" arguments = tool_call.function.arguments\n",
|
" arguments = tool_call.function.arguments\n",
|
||||||
" else: # For GPT and Claude\n",
|
" else:\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" arguments = json.loads(tool_call.function.arguments)\n",
|
" arguments = json.loads(tool_call.function.arguments)\n",
|
||||||
" except (json.JSONDecodeError, TypeError):\n",
|
" except (json.JSONDecodeError, TypeError):\n",
|
||||||
" arguments = {'text': tool_call.function.arguments}\n",
|
" arguments = {'text': tool_call.function.arguments}\n",
|
||||||
|
" elif hasattr(tool_call, 'name'): # Gemini\n",
|
||||||
|
" function_name = tool_call.name\n",
|
||||||
|
" arguments = tool_call.args\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Fallback if arguments are not parsed correctly\n",
|
" # Fallback if arguments are not parsed correctly\n",
|
||||||
" if not arguments or 'text' not in arguments:\n",
|
" if not arguments or 'text' not in arguments:\n",
|
||||||
@@ -327,7 +320,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 75,
|
"execution_count": null,
|
||||||
"id": "14083620-1b16-4c8b-8365-c221b831e678",
|
"id": "14083620-1b16-4c8b-8365-c221b831e678",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -348,7 +341,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 76,
|
"execution_count": null,
|
||||||
"id": "f9601a49-a490-4454-bd47-591ad793dc30",
|
"id": "f9601a49-a490-4454-bd47-591ad793dc30",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -459,22 +452,42 @@
|
|||||||
" final_response_content = response.content[0].text\n",
|
" final_response_content = response.content[0].text\n",
|
||||||
"\n",
|
"\n",
|
||||||
" elif target_model == \"Gemini\":\n",
|
" elif target_model == \"Gemini\":\n",
|
||||||
|
" messages_gemini = []\n",
|
||||||
|
" for m in history:\n",
|
||||||
|
" messages_gemini.append({\"role\": \"user\", \"parts\": [{\"text\": m[0]}]})\n",
|
||||||
|
" if m[1]:\n",
|
||||||
|
" messages_gemini.append({\"role\": \"model\", \"parts\": [{\"text\": m[1]}]})\n",
|
||||||
|
" \n",
|
||||||
" model = google.generativeai.GenerativeModel(\n",
|
" model = google.generativeai.GenerativeModel(\n",
|
||||||
" model_name=MODEL_GEMINI,\n",
|
" model_name=MODEL_GEMINI,\n",
|
||||||
" system_instruction=system_message,\n",
|
" system_instruction=system_message,\n",
|
||||||
" tools=tools_gemini\n",
|
" tools=tools_gemini\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" \n",
|
" \n",
|
||||||
" chat = model.start_chat(history=messages)\n",
|
" chat = model.start_chat(history=messages_gemini[:-1])\n",
|
||||||
" response = chat.send_message(user_message)\n",
|
" response = chat.send_message(messages_gemini[-1])\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" # Check if the response is a tool call before trying to extract text\n",
|
||||||
" if response.candidates[0].content.parts[0].function_call:\n",
|
" if response.candidates[0].content.parts[0].function_call:\n",
|
||||||
" tool_call = response.candidates[0].content.parts[0].function_call\n",
|
" tool_call = response.candidates[0].content.parts[0].function_call\n",
|
||||||
" response_tool = handle_tool_call(tool_call, user_message)\n",
|
" response_tool = handle_tool_call(tool_call, user_message)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" chat.send_message(response_tool)\n",
|
" tool_response_content = json.loads(response_tool[\"content\"])\n",
|
||||||
" response = chat.send_message(user_message)\n",
|
" tool_response_gemini = {\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"parts\": [{\n",
|
||||||
|
" \"function_response\": {\n",
|
||||||
|
" \"name\": tool_call.name,\n",
|
||||||
|
" \"response\": tool_response_content\n",
|
||||||
|
" }\n",
|
||||||
|
" }]\n",
|
||||||
|
" }\n",
|
||||||
" \n",
|
" \n",
|
||||||
|
" # Send the tool output back and get a new response\n",
|
||||||
|
" response = chat.send_message(tool_response_gemini)\n",
|
||||||
|
" final_response_content = response.text\n",
|
||||||
|
" else:\n",
|
||||||
|
" # If the original response was not a tool call, get the text directly\n",
|
||||||
" final_response_content = response.text\n",
|
" final_response_content = response.text\n",
|
||||||
"\n",
|
"\n",
|
||||||
" elif target_model == \"Ollama\":\n",
|
" elif target_model == \"Ollama\":\n",
|
||||||
@@ -576,14 +589,6 @@
|
|||||||
"ui.launch(inbrowser=True)\n"
|
"ui.launch(inbrowser=True)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "f6e1e727-4c55-4ed5-b50e-5388b246c8c5",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|||||||
Reference in New Issue
Block a user