# Default chat model name; the DOCGEN_MODEL environment variable overrides it.
MODEL = os.getenv("DOCGEN_MODEL", "gpt-4o-mini")


# Registry for multiple model providers.
# Each row below is (registry key, provider tag, model identifier); the
# comprehension expands it to {key: {"provider": ..., "model": ...}} and
# keeps the rows' order.
MODEL_REGISTRY = {
    label: {"provider": provider, "model": model}
    for label, provider, model in (
        ("gpt-4o-mini (OpenAI)", "openai", "gpt-4o-mini"),
        ("gpt-4o (OpenAI)", "openai", "gpt-4o"),
        ("claude-3.5-sonnet (Anthropic)", "anthropic", "claude-3.5-sonnet"),
        ("gemini-1.5-pro (Google)", "google", "gemini-1.5-pro"),
        ("codellama-7b (Open Source)", "open_source", "codellama-7b"),
        ("starcoder2 (Open Source)", "open_source", "starcoder2"),
    )
}
# ================================================================
# LLM Chat Helper — multi-provider docstring generation
# ================================================================
def llm_generate_docstring(signature: str, context: str, style: str = "google",
                           module_name: str = "module", model_choice: str = "gpt-4o-mini (OpenAI)") -> str:
    """
    Generate a Python docstring using the selected model provider.

    Args:
        signature: Signature text of the target; forwarded to the prompt
            builder and echoed in the simulated open-source output.
        context: Source text forwarded to the prompt builder.
        style: Docstring style name forwarded to ``make_user_prompt``.
        module_name: Module name forwarded to ``make_user_prompt``.
        model_choice: Key into ``MODEL_REGISTRY``. An unknown key falls back
            to the ``"gpt-4o-mini (OpenAI)"`` entry instead of raising.

    Returns:
        The stripped text between the first pair of triple double quotes in
        the provider's reply, or the whole reply when no such pair exists.
    """
    import re  # function-scope import so the helper does not rely on a file-level one

    user_prompt = make_user_prompt(style, module_name, signature, context)
    model_info = MODEL_REGISTRY.get(model_choice, MODEL_REGISTRY["gpt-4o-mini (OpenAI)"])

    provider = model_info["provider"]
    model_name = model_info["model"]

    if provider == "openai":
        response = openai.chat.completions.create(
            model=model_name,
            temperature=0.2,
            messages=[
                {"role": "system", "content": "You are a senior Python engineer and technical writer."},
                {"role": "user", "content": user_prompt},
            ],
        )
        # The SDK types `message.content` as optional; treat a missing reply
        # as empty text rather than failing on `None.strip()`.
        text = (response.choices[0].message.content or "").strip()

    elif provider == "anthropic":
        # Future: integrate Anthropic SDK
        text = "Claude response simulation: " + user_prompt[:200]

    elif provider == "google":
        # Future: integrate Gemini API
        text = "Gemini response simulation: " + user_prompt[:200]

    else:
        # Simulated open-source fallback
        text = f"[Simulated output from {model_name}]\nAuto-generated docstring for {signature}"

    # Keep only the text inside triple quotes when the reply wraps it that way.
    match = re.search(r'"""(.*?)"""', text, re.S)
    return match.group(1).strip() if match else text


def generate_docstrings_for_source(src: str, style: str = "google", module_name: str = "module", model_choice: str = "gpt-4o-mini (OpenAI)"):
    """
    Find all missing docstrings, generate them via LLM,
    and insert them back into the source code.

    Args:
        src: Python source text to process.
        style: Docstring style forwarded to ``llm_generate_docstring``.
        module_name: Module name used in the prompt and in the module-level
            signature label.
        model_choice: ``MODEL_REGISTRY`` key forwarded to
            ``llm_generate_docstring`` and recorded in each report row.

    Returns:
        A ``(updated_source, report)`` tuple. ``report`` has one dict per
        generated docstring with the keys ``kind``, ``signature``, ``model``
        and ``doc_preview`` (first 150 characters), module entry first.
    """
    targets = find_missing_docstrings(src)
    report = []
    generated = []  # (kind, node, doc), in generation order

    # Phase 1: generate every docstring against the ORIGINAL source, module
    # first, so the LLM call order and report order are deterministic.
    for kind, node in sorted(targets, key=lambda t: 0 if t[0] == "module" else 1):
        sig = "module " + module_name if kind == "module" else node_signature(node)
        ctx = src if kind == "module" else context_snippet(src, node)
        doc = llm_generate_docstring(sig, ctx, style=style, module_name=module_name, model_choice=model_choice)

        generated.append((kind, node, doc))
        report.append({"kind": kind, "signature": sig, "model": model_choice, "doc_preview": doc[:150]})

    # Phase 2: insert bottom-up. The nodes come from `src`, so inserting text
    # above a node would shift the lines it refers to; working from the last
    # definition to the first keeps every pending position valid.
    # NOTE(review): assumes insert_docstring locates its insertion point from
    # the node's line numbers in the text it is given — confirm.
    updated = src
    definitions = [entry for entry in generated if entry[0] != "module"]
    definitions.sort(key=lambda entry: getattr(entry[1], "lineno", 0), reverse=True)
    for _kind, node, doc in definitions:
        updated = insert_docstring(updated, node, doc)

    # The module docstring goes in last, after everything below it is placed.
    for kind, _node, doc in generated:
        if kind == "module":
            updated = insert_module_docstring(updated, doc)

    return updated, report
\n", - "\"\"\"A simple counter class to track increments.\n", - "\"\"\"Increments the total attribute by one.\n", - "\n", - "This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is typically used to track counts or totals within the class context.\"\"\"\n", - "\n", - "\n", - "This class provides a method to increment a total count. \n", - "It initializes the total count to zero and allows for \n", - "incrementing it by one each time the `inc` method is called.\n", - "\n", - "Args:\n", - " None\n", - "Returns:\n", - " None\"\"\"\n", - "\n", - "It supports both integers and floats.\n", - "\n", - "Args:\n", - " x: The first number to add.\n", - " y: The second number to add.\n", - "\n", - "Returns:\n", - " The sum of x and y.\"\"\"\n", - "\n", - "Args:\n", - " x: The first number to add.\n", - " y: The second number to add.\n", - "Returns:\n", - " The sum of x and y.\"\"\"\n", - "\n", - "def add(x, y):\n", - " return x + y\n", - "\n", - "class Counter:\n", - " def inc(self):\n", - " self.total += 1\n" - ] - } - ], + "outputs": [], "source": [ "## Quick Test ##\n", "new_code, report = generate_docstrings_for_source(code, style=\"google\", module_name=\"demo\")\n", @@ -499,86 +464,10 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "b318db41-c05d-48ce-9990-b6f1a0577c68", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "=== Generated Docstrings ===\n", - "- module: module demo\n", - " Adds two numbers and returns the result.\n", - "\n", - "Args:\n", - " x: The first number to add.\n", - " y: The second number to add.\n", - "Returns:\n", - " The sum of x and y.\n", - "- function: def add(x, y):\n", - " Returns the sum of two numbers.\n", - "\n", - "This function takes two numerical inputs and returns their sum. 
\n", - "It supports both integers and floats.\n", - "\n", - "Args:\n", - " x: \n", - "- class: class Counter:\n", - " A simple counter class to track increments.\n", - "\n", - "This class provides a method to increment a total count. \n", - "It initializes the total count to zero and allo\n", - "- function: def inc(self):\n", - " Increments the total attribute by one.\n", - "\n", - "This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is\n", - "\n", - "=== Updated Source ===\n", - "\"\"\"Adds two numbers and returns the result.\n", - "\n", - "\"\"\"Returns the sum of two numbers.\n", - "\n", - "This function takes two numerical inputs and returns their sum. \n", - "\"\"\"A simple counter class to track increments.\n", - "\"\"\"Increments the total attribute by one.\n", - "\n", - "This method updates the instance's total value, which is expected to be an integer, by adding one to it. It is typically used to track counts or totals within the class context.\"\"\"\n", - "\n", - "\n", - "This class provides a method to increment a total count. 
\n", - "It initializes the total count to zero and allows for \n", - "incrementing it by one each time the `inc` method is called.\n", - "\n", - "Args:\n", - " None\n", - "Returns:\n", - " None\"\"\"\n", - "\n", - "It supports both integers and floats.\n", - "\n", - "Args:\n", - " x: The first number to add.\n", - " y: The second number to add.\n", - "\n", - "Returns:\n", - " The sum of x and y.\"\"\"\n", - "\n", - "Args:\n", - " x: The first number to add.\n", - " y: The second number to add.\n", - "Returns:\n", - " The sum of x and y.\"\"\"\n", - "\n", - "def add(x, y):\n", - " return x + y\n", - "\n", - "class Counter:\n", - " def inc(self):\n", - " self.total += 1\n" - ] - } - ], + "outputs": [], "source": [ "# ================================================================\n", "# 📂 File-Based Workflow — preview or apply docstrings\n", @@ -613,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "8962cf0e-9255-475e-bbc1-21500be0cd78", "metadata": {}, "outputs": [], @@ -651,86 +540,50 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "b0b0f852-982f-4918-9b5d-89880cc12003", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Running on local URL: http://127.0.0.1:7864\n", - "* To create a public link, set `share=True` in `launch()`.\n" - ] - }, - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "