301 lines
8.6 KiB
Plaintext
301 lines
8.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "45ca91c2",
|
|
"metadata": {},
|
|
"source": [
|
|
"# AI tool to add comments to the provided Java code\n",
|
|
"\n",
|
"Here we build a Gradio app that uses frontier models to add comments to Java code. For testing purposes I have used the *cheaper* versions of the models, not the ones the leaderboards indicate as the best ones."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f44901f5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from openai import OpenAI\n",
|
|
"import google.generativeai as genai\n",
|
|
"import anthropic\n",
|
|
"import gradio as gr"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c47706b3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# environment\n",
"\n",
"# Load API keys from a local .env file; override=True lets values from .env\n",
"# replace variables that are already set in the process environment.\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"google_api_key = os.getenv('GOOGLE_API_KEY')\n",
"\n",
"# Report missing keys up front instead of failing later inside a provider call.\n",
"for key_name, key_value in [\n",
"    ('OPENAI_API_KEY', openai_api_key),\n",
"    ('ANTHROPIC_API_KEY', anthropic_api_key),\n",
"    ('GOOGLE_API_KEY', google_api_key),\n",
"]:\n",
"    if not key_value:\n",
"        print(f\"Warning: {key_name} is not set\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "35446b9a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Create one client per provider. No API key is passed explicitly here;\n",
"# presumably each SDK picks up its key from the environment loaded above -- confirm.\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"genai.configure()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0e899efd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Model identifiers, one per provider (the cheaper tiers are used for testing).\n",
"OPENAI_MODEL = \"gpt-4o-mini\"\n",
"CLAUDE_MODEL = \"claude-3-haiku-20240307\"\n",
"GEMINI_MODEL = \"gemini-2.0-flash-lite\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "47640f53",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# System prompt shared by all three providers.\n",
"# Each fragment ends with a space so the sentences do not run together.\n",
"system_message = (\n",
"    \"You are an assistant that adds comments to java code. \"\n",
"    \"Do not make any changes to the code itself. \"\n",
"    \"Use comments sparingly. Only add them in places where they help to understand how the code works. \"\n",
"    \"Do not comment every single line of the code.\"\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f41ccbf0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def user_prompt_for(code):\n",
"    \"\"\"Build the user prompt: a fixed instruction followed by the Java source in `code`.\"\"\"\n",
"    instruction = (\n",
"        \"Add helpful comments to this java code. \"\n",
"        \"Do not change the code itself.\\n\\n\"\n",
"    )\n",
"    return instruction + code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c57c0000",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Sample Java class used as the default content of the \"Java code:\" textbox in the UI.\n",
"test_code = \"\"\"\n",
"package com.hma.kafkaproducertest.producer;\n",
"\n",
"import com.hma.kafkaproducertest.model.TestDTO;\n",
"import org.springframework.cloud.stream.function.StreamBridge;\n",
"import org.springframework.messaging.Message;\n",
"import org.springframework.messaging.support.MessageBuilder;\n",
"import org.springframework.stereotype.Component;\n",
"\n",
"import java.util.Arrays;\n",
"import java.util.Comparator;\n",
"import java.util.StringJoiner;\n",
"import java.util.stream.Collectors;\n",
"import java.util.stream.IntStream;\n",
"\n",
"@Component\n",
"public class TestProducer {\n",
"\n",
" public static final String EVENT_TYPE_HEADER = \"event-type\";\n",
" private static final String BINDING_NAME = \"testProducer-out-0\";\n",
"\n",
" private final StreamBridge streamBridge;\n",
"\n",
" public TestProducer(StreamBridge streamBridge) {\n",
" this.streamBridge = streamBridge;\n",
" }\n",
"\n",
" public void sendMessage(TestDTO payload, String eventType){\n",
" Message<TestDTO> message = MessageBuilder\n",
" .withPayload(payload)\n",
" .setHeader(EVENT_TYPE_HEADER, eventType)\n",
" .build();\n",
"\n",
" streamBridge.send(BINDING_NAME, message);\n",
" }\n",
"\n",
" public void test(String t1, String t2) {\n",
" var s = t1.length() > t2.length() ? t2 : t1;\n",
" var l = t1.length() > t2.length() ? t1 : t2;\n",
" var res = true;\n",
" for (int i = 0; i < s.length(); i++) {\n",
" if (s.charAt(i) == l.charAt(i)) {\n",
" res = false;\n",
" break;\n",
" }\n",
" }\n",
" System.out.println(res);\n",
" }\n",
"}\n",
"\"\"\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "00c71128",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def stream_gpt(code):\n",
"    \"\"\"Stream a commented version of `code` from the OpenAI chat completions API.\n",
"\n",
"    Yields the accumulated response text after each content-bearing chunk, so\n",
"    every yielded value is the full text received so far, not just the newest delta.\n",
"    \"\"\"\n",
"    messages = [\n",
"        {\"role\": \"system\", \"content\": system_message},\n",
"        {\"role\": \"user\", \"content\": user_prompt_for(code)},\n",
"    ]\n",
"    stream = openai.chat.completions.create(\n",
"        model=OPENAI_MODEL,\n",
"        messages=messages,\n",
"        stream=True,\n",
"    )\n",
"    result = \"\"\n",
"    for chunk in stream:\n",
"        # Some chunks (e.g. usage-only ones) can carry an empty `choices` list;\n",
"        # skip them instead of failing on `choices[0]`.\n",
"        if not chunk.choices:\n",
"            continue\n",
"        # `delta.content` can be None, hence the `or \"\"`.\n",
"        result += chunk.choices[0].delta.content or \"\"\n",
"        yield result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ca92f8a8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def stream_claude(code):\n",
"    \"\"\"Stream a commented version of `code` from the Anthropic messages API.\n",
"\n",
"    Yields the accumulated response text each time a new text fragment arrives.\n",
"    \"\"\"\n",
"    user_turn = {\"role\": \"user\", \"content\": user_prompt_for(code)}\n",
"    accumulated = \"\"\n",
"    with claude.messages.stream(\n",
"        model=CLAUDE_MODEL,\n",
"        max_tokens=2000,\n",
"        system=system_message,\n",
"        messages=[user_turn],\n",
"    ) as stream:\n",
"        for fragment in stream.text_stream:\n",
"            accumulated += fragment or \"\"\n",
"            yield accumulated"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9dffed4b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def stream_gemini(code):\n",
"    \"\"\"Stream a commented version of `code` from the Gemini model.\n",
"\n",
"    Yields the accumulated response text after every streamed chunk.\n",
"    \"\"\"\n",
"    gemini_model = genai.GenerativeModel(\n",
"        model_name=GEMINI_MODEL,\n",
"        system_instruction=system_message,\n",
"    )\n",
"    chunks = gemini_model.generate_content(user_prompt_for(code), stream=True)\n",
"    accumulated = \"\"\n",
"    # NOTE(review): `chunk.text` is read unguarded; confirm it cannot raise for\n",
"    # chunks that carry no text part.\n",
"    for chunk in chunks:\n",
"        accumulated += chunk.text or \"\"\n",
"        yield accumulated"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "31f9c267",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def comment_code(code, model):\n",
"    \"\"\"Stream the commented code from the provider selected by `model`.\n",
"\n",
"    `model` must be \"GPT\", \"Claude\" or \"Gemini\"; any other value raises\n",
"    ValueError(\"Unknown model\"). Because this function is a generator, that\n",
"    error surfaces when iteration starts, not when the function is called.\n",
"    \"\"\"\n",
"    match model:\n",
"        case \"GPT\":\n",
"            yield from stream_gpt(code)\n",
"        case \"Claude\":\n",
"            yield from stream_claude(code)\n",
"        case \"Gemini\":\n",
"            yield from stream_gemini(code)\n",
"        case _:\n",
"            raise ValueError(\"Unknown model\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c04c0a1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Two rows: input and output panes on top, model picker and button below.\n",
"with gr.Blocks() as ui:\n",
"    with gr.Row():\n",
"        original_code = gr.Textbox(label=\"Java code:\", lines=10, value=test_code)\n",
"        commented_code = gr.Markdown(label=\"Commented code:\")\n",
"    with gr.Row():\n",
"        model = gr.Dropdown(\n",
"            choices=[\"GPT\", \"Claude\", \"Gemini\"],\n",
"            label=\"Select model\",\n",
"            value=\"GPT\",\n",
"        )\n",
"        comment = gr.Button(\"Comment code\")\n",
"\n",
"    # Wire the button: inputs are the code box and the model choice,\n",
"    # output is the Markdown pane.\n",
"    comment.click(\n",
"        fn=comment_code,\n",
"        inputs=[original_code, model],\n",
"        outputs=[commented_code],\n",
"    )\n",
"\n",
"ui.launch(inbrowser=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "84d33a5f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Close the `ui` app launched in the previous cell.\n",
"# NOTE(review): under Restart & Run All this executes right after launch --\n",
"# presumably meant to be run by hand once you are done with the app.\n",
"ui.close()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bbd50bf7",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Conclusion\n",
|
|
"\n",
|
"In my personal opinion, at least when using these *cheaper* versions of the models, the result provided by Claude is the best. ChatGPT adds far too many comments even though the system message discourages that. Gemini also provides a good result, but perhaps adds a tad too few comments -- although that certainly depends on your personal preferences."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "llms",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|