added support for multi-model selection

This commit is contained in:
Bharat Puri
2025-10-23 14:58:44 +05:30
parent bcc9342dcd
commit f381c6e10a

View File

@@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 11,
"id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
"metadata": {},
"outputs": [],
@@ -38,7 +38,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"id": "4f672e1c-87e9-4865-b760-370fa605e614",
"metadata": {},
"outputs": [
@@ -98,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 13,
"id": "59863df1",
"metadata": {},
"outputs": [],
@@ -122,6 +122,35 @@
"openrouter = OpenAI(api_key=openrouter_api_key, base_url=openrouter_url)\n",
"\n",
"MODEL = os.getenv(\"DOCGEN_MODEL\", \"gpt-4o-mini\")\n",
"\n",
"\n",
"# Registry for multiple model providers\n",
"MODEL_REGISTRY = {\n",
" \"gpt-4o-mini (OpenAI)\": {\n",
" \"provider\": \"openai\",\n",
" \"model\": \"gpt-4o-mini\",\n",
" },\n",
" \"gpt-4o (OpenAI)\": {\n",
" \"provider\": \"openai\",\n",
" \"model\": \"gpt-4o\",\n",
" },\n",
" \"claude-3.5-sonnet (Anthropic)\": {\n",
" \"provider\": \"anthropic\",\n",
" \"model\": \"claude-3.5-sonnet\",\n",
" },\n",
" \"gemini-1.5-pro (Google)\": {\n",
" \"provider\": \"google\",\n",
" \"model\": \"gemini-1.5-pro\",\n",
" },\n",
" \"codellama-7b (Open Source)\": {\n",
" \"provider\": \"open_source\",\n",
" \"model\": \"codellama-7b\",\n",
" },\n",
" \"starcoder2 (Open Source)\": {\n",
" \"provider\": \"open_source\",\n",
" \"model\": \"starcoder2\",\n",
" },\n",
"}\n",
"\n"
]
},
@@ -141,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"id": "17b7d074-b1a4-4673-adec-918f82a4eff0",
"metadata": {},
"outputs": [],
@@ -189,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"id": "16b3c10f-f7bc-4a2f-a22f-65c6807b7574",
"metadata": {},
"outputs": [],
@@ -197,33 +226,49 @@
"# ================================================================\n",
"# LLM Chat Helper — OpenAI GPT\n",
"# ================================================================\n",
"\n",
"def llm_generate_docstring(signature: str, context: str, style: str = \"google\", module_name: str = \"module\") -> str:\n",
"def llm_generate_docstring(signature: str, context: str, style: str = \"google\", \n",
" module_name: str = \"module\", model_choice: str = \"gpt-4o-mini (OpenAI)\") -> str:\n",
" \"\"\"\n",
" Sends a chat completion request to OpenAI GPT model to generate\n",
" a docstring based on code context and function signature.\n",
" Generate a Python docstring using the selected model provider.\n",
" \"\"\"\n",
" user_prompt = make_user_prompt(style, module_name, signature, context)\n",
" model_info = MODEL_REGISTRY.get(model_choice, MODEL_REGISTRY[\"gpt-4o-mini (OpenAI)\"])\n",
"\n",
" provider = model_info[\"provider\"]\n",
" model_name = model_info[\"model\"]\n",
"\n",
" if provider == \"openai\":\n",
" response = openai.chat.completions.create(\n",
" model=MODEL,\n",
" model=model_name,\n",
" temperature=0.2,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a senior Python engineer and technical writer.\"},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ]\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" ],\n",
" )\n",
"\n",
" text = response.choices[0].message.content.strip()\n",
" # Extract only the text inside triple quotes if present\n",
"\n",
" elif provider == \"anthropic\":\n",
" # Future: integrate Anthropic SDK\n",
" text = \"Claude response simulation: \" + user_prompt[:200]\n",
"\n",
" elif provider == \"google\":\n",
" # Future: integrate Gemini API\n",
" text = \"Gemini response simulation: \" + user_prompt[:200]\n",
"\n",
" else:\n",
" # Simulated open-source fallback\n",
" text = f\"[Simulated output from {model_name}]\\nAuto-generated docstring for {signature}\"\n",
"\n",
" import re\n",
" match = re.search(r'\"\"\"(.*?)\"\"\"', text, re.S)\n",
" return (match.group(1).strip() if match else text)\n",
" return match.group(1).strip() if match else text\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 15,
"id": "82da91ac-e563-4425-8b45-1b94880d342f",
"metadata": {},
"outputs": [],
@@ -298,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 16,
"id": "ea69108f-e4ca-4326-89fe-97c5748c0e79",
"metadata": {},
"outputs": [
@@ -332,7 +377,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "00d65b96-e65d-4e11-89be-06f265a5f2e3",
"metadata": {},
"outputs": [],
@@ -379,11 +424,7 @@
" )\n",
"\n",
"\n",
"def generate_docstrings_for_source(src: str, style: str = \"google\", module_name: str = \"module\"):\n",
" \"\"\"\n",
" Find all missing docstrings, generate them via LLM,\n",
" and insert them back into the source code.\n",
" \"\"\"\n",
"def generate_docstrings_for_source(src: str, style: str = \"google\", module_name: str = \"module\", model_choice: str = \"gpt-4o-mini (OpenAI)\"):\n",
" targets = find_missing_docstrings(src)\n",
" updated = src\n",
" report = []\n",
@@ -391,100 +432,24 @@
" for kind, node in sorted(targets, key=lambda t: 0 if t[0] == \"module\" else 1):\n",
" sig = \"module \" + module_name if kind == \"module\" else node_signature(node)\n",
" ctx = src if kind == \"module\" else context_snippet(src, node)\n",
" doc = llm_generate_docstring(sig, ctx, style=style, module_name=module_name)\n",
" doc = llm_generate_docstring(sig, ctx, style=style, module_name=module_name, model_choice=model_choice)\n",
"\n",
" if kind == \"module\":\n",
" updated = insert_module_docstring(updated, doc)\n",
" else:\n",
" updated = insert_docstring(updated, node, doc)\n",
"\n",
" report.append({\"kind\": kind, \"signature\": sig, \"doc_preview\": doc[:150]})\n",
" report.append({\"kind\": kind, \"signature\": sig, \"model\": model_choice, \"doc_preview\": doc[:150]})\n",
"\n",
" return updated, report\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "d00cf4b7-773d-49cb-8262-9d11d787ee10",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Generated Docstrings ===\n",
"- module: module demo\n",
" Adds two numbers and returns the result.\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"Returns:\n",
" The sum of x and y.\n",
"- function: def add(x, y):\n",
" Returns the sum of two numbers.\n",
"\n",
"This function takes two numerical inputs and returns their sum. \n",
"It supports both integers and floats.\n",
"\n",
"Args:\n",
" x: \n",
"- class: class Counter:\n",
" A simple counter class to track increments.\n",
"\n",
"This class provides a method to increment a total count. \n",
"It initializes the total count to zero and allo\n",
"- function: def inc(self):\n",
" Increments the total attribute by one.\n",
"\n",
"This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is\n",
"\n",
"=== Updated Source ===\n",
"\"\"\"Adds two numbers and returns the result.\n",
"\n",
"\"\"\"Returns the sum of two numbers.\n",
"\n",
"This function takes two numerical inputs and returns their sum. \n",
"\"\"\"A simple counter class to track increments.\n",
"\"\"\"Increments the total attribute by one.\n",
"\n",
"This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is typically used to track counts or totals within the class context.\"\"\"\n",
"\n",
"\n",
"This class provides a method to increment a total count. \n",
"It initializes the total count to zero and allows for \n",
"incrementing it by one each time the `inc` method is called.\n",
"\n",
"Args:\n",
" None\n",
"Returns:\n",
" None\"\"\"\n",
"\n",
"It supports both integers and floats.\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"\n",
"Returns:\n",
" The sum of x and y.\"\"\"\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"Returns:\n",
" The sum of x and y.\"\"\"\n",
"\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
"class Counter:\n",
" def inc(self):\n",
" self.total += 1\n"
]
}
],
"outputs": [],
"source": [
"## Quick Test ##\n",
"new_code, report = generate_docstrings_for_source(code, style=\"google\", module_name=\"demo\")\n",
@@ -499,86 +464,10 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "b318db41-c05d-48ce-9990-b6f1a0577c68",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Generated Docstrings ===\n",
"- module: module demo\n",
" Adds two numbers and returns the result.\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"Returns:\n",
" The sum of x and y.\n",
"- function: def add(x, y):\n",
" Returns the sum of two numbers.\n",
"\n",
"This function takes two numerical inputs and returns their sum. \n",
"It supports both integers and floats.\n",
"\n",
"Args:\n",
" x: \n",
"- class: class Counter:\n",
" A simple counter class to track increments.\n",
"\n",
"This class provides a method to increment a total count. \n",
"It initializes the total count to zero and allo\n",
"- function: def inc(self):\n",
" Increments the total attribute by one.\n",
"\n",
"This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is\n",
"\n",
"=== Updated Source ===\n",
"\"\"\"Adds two numbers and returns the result.\n",
"\n",
"\"\"\"Returns the sum of two numbers.\n",
"\n",
"This function takes two numerical inputs and returns their sum. \n",
"\"\"\"A simple counter class to track increments.\n",
"\"\"\"Increments the total attribute by one.\n",
"\n",
"This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is typically used to track counts or totals within the class context.\"\"\"\n",
"\n",
"\n",
"This class provides a method to increment a total count. \n",
"It initializes the total count to zero and allows for \n",
"incrementing it by one each time the `inc` method is called.\n",
"\n",
"Args:\n",
" None\n",
"Returns:\n",
" None\"\"\"\n",
"\n",
"It supports both integers and floats.\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"\n",
"Returns:\n",
" The sum of x and y.\"\"\"\n",
"\n",
"Args:\n",
" x: The first number to add.\n",
" y: The second number to add.\n",
"Returns:\n",
" The sum of x and y.\"\"\"\n",
"\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
"class Counter:\n",
" def inc(self):\n",
" self.total += 1\n"
]
}
],
"outputs": [],
"source": [
"# ================================================================\n",
"# 📂 File-Based Workflow — preview or apply docstrings\n",
@@ -613,7 +502,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"id": "8962cf0e-9255-475e-bbc1-21500be0cd78",
"metadata": {},
"outputs": [],
@@ -651,86 +540,50 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"id": "b0b0f852-982f-4918-9b5d-89880cc12003",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7864\n",
"* To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# ================================================================\n",
"# 🎨 Gradio Interface — Auto Docstring Generator\n",
"# 🎨 Enhanced Gradio Interface with Model Selector\n",
"# ================================================================\n",
"import gradio as gr\n",
"\n",
"def gradio_generate(code_text: str, style: str):\n",
" \"\"\"Wrapper for Gradio — generates docstrings for pasted code.\"\"\"\n",
"def gradio_generate(code_text: str, style: str, model_choice: str):\n",
" \"\"\"Wrapper for Gradio — generates docstrings using selected model.\"\"\"\n",
" if not code_text.strip():\n",
" return \"⚠️ Please paste some Python code first.\"\n",
" try:\n",
" updated, _ = generate_docstrings_for_source(\n",
" code_text, style=style, module_name=\"gradio_snippet\"\n",
" code_text, style=style, module_name=\"gradio_snippet\", model_choice=model_choice\n",
" )\n",
" return updated\n",
" except Exception as e:\n",
" return f\"❌ Error: {e}\"\n",
"\n",
"# Build Gradio UI\n",
"with gr.Blocks(theme=gr.themes.Soft()) as doc_ui:\n",
" gr.Markdown(\"## 🧠 Auto Docstring Generator — by Bharat Puri\\nPaste your Python code below and click **Generate Docstrings**.\")\n",
" gr.Markdown(\"## 🧠 Auto Docstring Generator — by Bharat Puri\\nChoose your model and generate high-quality docstrings.\")\n",
"\n",
" with gr.Row():\n",
" code_input = gr.Code(\n",
" label=\"Paste your Python code here\",\n",
" language=\"python\",\n",
" lines=20,\n",
" value=\"def add(a, b):\\n return a + b\\n\\nclass Counter:\\n def inc(self):\\n self.total += 1\",\n",
" )\n",
" code_output = gr.Code(\n",
" label=\"Generated code with docstrings\",\n",
" language=\"python\",\n",
" lines=20,\n",
" )\n",
" code_input = gr.Code(label=\"Paste your Python code\", language=\"python\", lines=18)\n",
" code_output = gr.Code(label=\"Generated code with docstrings\", language=\"python\", lines=18)\n",
"\n",
" style_choice = gr.Radio(\n",
" [\"google\"], value=\"google\", label=\"Docstring Style\"\n",
" with gr.Row():\n",
" style_choice = gr.Radio([\"google\"], value=\"google\", label=\"Docstring Style\")\n",
" model_choice = gr.Dropdown(\n",
" list(MODEL_REGISTRY.keys()),\n",
" value=\"gpt-4o-mini (OpenAI)\",\n",
" label=\"Select Model\",\n",
" )\n",
"\n",
" generate_btn = gr.Button(\"🚀 Generate Docstrings\")\n",
" generate_btn.click(\n",
" fn=gradio_generate,\n",
" inputs=[code_input, style_choice],\n",
" inputs=[code_input, style_choice, model_choice],\n",
" outputs=[code_output],\n",
" )\n",
"\n",
"# Launch app\n",
"doc_ui.launch(share=False)\n"
]
}