477 lines
16 KiB
Plaintext
477 lines
16 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xeOG96gXPeqz"
|
|
},
|
|
"source": [
|
|
"# Snippet Sniper\n",
|
|
"\n",
|
|
"### Welcome to a wild ride with John Wick in the coding arena, as he accepts your contracts\n",
|
|
"\n",
|
|
"Lets you perform various tasks on a given code snippet:\n",
|
|
"\n",
|
|
"- Add comments\n",
|
|
"- Explain what the code does\n",
|
|
"- Write comprehensive unit tests\n",
|
|
"- Fix (potential) errors in the code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "B7ftYo53Pw94",
|
|
"outputId": "9daa3972-d5a1-4cd2-9952-cd89a54c6ddd"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import logging\n",
|
|
"from enum import StrEnum\n",
|
|
"from getpass import getpass\n",
|
|
"\n",
|
|
"import gradio as gr\n",
|
|
"from openai import OpenAI\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"\n",
|
|
"\n",
|
|
"load_dotenv(override=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "AXmPDuydPuUp"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"logging.basicConfig(level=logging.WARNING)\n",
|
|
"\n",
|
|
"logger = logging.getLogger('sniper')\n",
|
|
"logger.setLevel(logging.DEBUG)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0c_e1iMYmp5o"
|
|
},
|
|
"source": [
|
|
"## Free Cloud Providers\n",
|
|
"\n",
|
|
"Grab a free API key from one of these generous providers:\n",
|
|
"\n",
|
|
"- https://ollama.com/"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Secrets Helpers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_secret_in_google_colab(env_name: str) -> str:\n",
|
|
" try:\n",
|
|
" from google.colab import userdata\n",
|
|
" return userdata.get(env_name)\n",
|
|
" except Exception:\n",
|
|
" return ''\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_secret(env_name: str) -> str:\n",
|
|
" '''Gets the value from the environment(s), otherwise ask the user for it if not set'''\n",
|
|
" key = os.environ.get(env_name) or get_secret_in_google_colab(env_name)\n",
|
|
"\n",
|
|
" if not key:\n",
|
|
" key = getpass(f'Enter {env_name}:').strip()\n",
|
|
"\n",
|
|
" if key:\n",
|
|
" logger.info(f'✅ {env_name} provided')\n",
|
|
" else:\n",
|
|
" logger.warning(f'❌ {env_name} not provided')\n",
|
|
" return key.strip()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Set up model(s)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "d7Qmfac9Ph0w",
|
|
"outputId": "be9db7f3-f08a-47f5-d6fa-d7c8bce4f97a"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Provider(StrEnum):\n",
|
|
" OLLAMA = 'Ollama'\n",
|
|
" OPENROUTER = 'OpenRouter'\n",
|
|
"\n",
|
|
"clients: dict[Provider, OpenAI] = {}\n",
|
|
"\n",
|
|
"if api_key := get_secret('OLLAMA_API_KEY'):\n",
|
|
" clients[Provider.OLLAMA] = OpenAI(api_key=api_key, base_url='https://ollama.com/v1')\n",
|
|
"\n",
|
|
"model = 'qwen3-coder:480b-cloud'\n",
|
|
"client = clients.get(Provider.OLLAMA)\n",
|
|
"if not client:\n",
|
|
" raise Exception('No client found')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Kq-AKZEjqnTp"
|
|
},
|
|
"source": [
|
|
"## Tasks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "fTHvG2w0sgwU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Task(StrEnum):\n",
|
|
" COMMENTS = 'Comments'\n",
|
|
" UNIT_TESTS = 'Unit Tests'\n",
|
|
" FIX_CODE = 'Fix Code'\n",
|
|
" EXPLAIN = 'Explain'\n",
|
|
"\n",
|
|
"\n",
|
|
"def perform_tasks(tasks, code):\n",
|
|
" logger.info(f'Performing tasks: {tasks}')\n",
|
|
"\n",
|
|
" steps = []\n",
|
|
" if Task.COMMENTS in tasks:\n",
|
|
" steps.append('Add documentation comments to the given code. If the method name and parameters are self-explanatory, skip those comments.')\n",
|
|
" if Task.UNIT_TESTS in tasks:\n",
|
|
" steps.append('Add a thorough unit tests considering all edge cases to the given code.')\n",
|
|
" if Task.FIX_CODE in tasks:\n",
|
|
" steps.append('You are to fix the given code, if it has any issues.')\n",
|
|
" if Task.EXPLAIN in tasks:\n",
|
|
" steps.append('Explain the given code.')\n",
|
|
"\n",
|
|
" system_prompt = f'''\n",
|
|
" You are an experienced polyglot software engineer and given a code you can\n",
|
|
" detect what programming language it is in.\n",
|
|
" DO NOT fix the code until expressly told to do so.\n",
|
|
"\n",
|
|
" Your tasks:\n",
|
|
" {'- ' + '\\n- '.join(steps)}\n",
|
|
" '''\n",
|
|
" messages = [\n",
|
|
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
|
" {\"role\": \"user\", \"content\": f'Code: \\n{code}'}\n",
|
|
" ]\n",
|
|
" response = client.chat.completions.create(\n",
|
|
" model=model,\n",
|
|
" messages=messages\n",
|
|
" )\n",
|
|
"\n",
|
|
" content = response.choices[0].message.content\n",
|
|
"\n",
|
|
" return content"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "SkmMYw_osxeG"
|
|
},
|
|
"source": [
|
|
"### Examples"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "nlzUyXFus0km"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_examples() -> tuple[list[any], list[str]]:\n",
|
|
" '''Returns examples and their labels'''\n",
|
|
"\n",
|
|
" # Python examples\n",
|
|
" add = r'''\n",
|
|
" def add(a, b):\n",
|
|
" return a + b\n",
|
|
" '''\n",
|
|
"\n",
|
|
" multiply = r'''\n",
|
|
" def multiply(a, b):\n",
|
|
" return a * b\n",
|
|
" '''\n",
|
|
"\n",
|
|
" divide = r'''\n",
|
|
" def divide(a, b):\n",
|
|
" return a / b\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # JavaScript example - async function\n",
|
|
" fetch_data = r'''\n",
|
|
" async function fetchUserData(userId) {\n",
|
|
" const response = await fetch(`/api/users/${userId}`);\n",
|
|
" const data = await response.json();\n",
|
|
" return data;\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # Java example - sorting algorithm\n",
|
|
" bubble_sort = r'''\n",
|
|
" public void bubbleSort(int[] arr) {\n",
|
|
" int n = arr.length;\n",
|
|
" for (int i = 0; i < n-1; i++) {\n",
|
|
" for (int j = 0; j < n-i-1; j++) {\n",
|
|
" if (arr[j] > arr[j+1]) {\n",
|
|
" int temp = arr[j];\n",
|
|
" arr[j] = arr[j+1];\n",
|
|
" arr[j+1] = temp;\n",
|
|
" }\n",
|
|
" }\n",
|
|
" }\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # C++ example - buggy pointer code\n",
|
|
" buggy_cpp = r'''\n",
|
|
" int* createArray() {\n",
|
|
" int arr[5] = {1, 2, 3, 4, 5};\n",
|
|
" return arr;\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # Rust example - ownership puzzle\n",
|
|
" rust_ownership = r'''\n",
|
|
" fn main() {\n",
|
|
" let s1 = String::from(\"hello\");\n",
|
|
" let s2 = s1;\n",
|
|
" println!(\"{}\", s1);\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # Go example - concurrent code\n",
|
|
" go_concurrent = r'''\n",
|
|
" func processData(data []int) int {\n",
|
|
" sum := 0\n",
|
|
" for _, v := range data {\n",
|
|
" sum += v\n",
|
|
" }\n",
|
|
" return sum\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # TypeScript example - complex type\n",
|
|
" ts_generics = r'''\n",
|
|
" function mergeObjects<T, U>(obj1: T, obj2: U): T & U {\n",
|
|
" return { ...obj1, ...obj2 };\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # Ruby example - metaclass magic\n",
|
|
" ruby_meta = r'''\n",
|
|
" class DynamicMethod\n",
|
|
" define_method(:greet) do |name|\n",
|
|
" \"Hello, #{name}!\"\n",
|
|
" end\n",
|
|
" end\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # PHP example - SQL injection vulnerable\n",
|
|
" php_vulnerable = r'''\n",
|
|
" function getUser($id) {\n",
|
|
" $query = \"SELECT * FROM users WHERE id = \" . $id;\n",
|
|
" return mysqli_query($conn, $query);\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # Python example - complex algorithm\n",
|
|
" binary_search = r'''\n",
|
|
" def binary_search(arr, target):\n",
|
|
" left, right = 0, len(arr) - 1\n",
|
|
" while left <= right:\n",
|
|
" mid = (left + right) // 2\n",
|
|
" if arr[mid] == target:\n",
|
|
" return mid\n",
|
|
" elif arr[mid] < target:\n",
|
|
" left = mid + 1\n",
|
|
" else:\n",
|
|
" right = mid - 1\n",
|
|
" return -1\n",
|
|
" '''\n",
|
|
"\n",
|
|
" # JavaScript example - closure concept\n",
|
|
" js_closure = r'''\n",
|
|
" function counter() {\n",
|
|
" let count = 0;\n",
|
|
" return function() {\n",
|
|
" count++;\n",
|
|
" return count;\n",
|
|
" };\n",
|
|
" }\n",
|
|
" '''\n",
|
|
"\n",
|
|
" examples = [\n",
|
|
" # Simple Python examples\n",
|
|
" [[Task.COMMENTS], add, 'python'],\n",
|
|
" [[Task.UNIT_TESTS], multiply, 'python'],\n",
|
|
" [[Task.COMMENTS, Task.FIX_CODE], divide, 'python'],\n",
|
|
"\n",
|
|
" # Explain complex concepts\n",
|
|
" [[Task.EXPLAIN], binary_search, 'python'],\n",
|
|
" [[Task.EXPLAIN], js_closure, 'javascript'],\n",
|
|
" [[Task.EXPLAIN], rust_ownership, 'rust'],\n",
|
|
"\n",
|
|
" # Unit tests for different languages\n",
|
|
" [[Task.UNIT_TESTS], fetch_data, 'javascript'],\n",
|
|
" [[Task.UNIT_TESTS], go_concurrent, 'go'],\n",
|
|
"\n",
|
|
" # Fix buggy code\n",
|
|
" [[Task.FIX_CODE], buggy_cpp, 'cpp'],\n",
|
|
" [[Task.FIX_CODE], php_vulnerable, 'php'],\n",
|
|
"\n",
|
|
" # Multi-task combinations\n",
|
|
" [[Task.COMMENTS, Task.EXPLAIN], bubble_sort, None],\n",
|
|
" [[Task.COMMENTS, Task.UNIT_TESTS], ts_generics, 'typescript'],\n",
|
|
" [[Task.EXPLAIN, Task.FIX_CODE], rust_ownership, 'rust'],\n",
|
|
" [[Task.COMMENTS, Task.UNIT_TESTS, Task.EXPLAIN], ruby_meta, 'ruby'],\n",
|
|
" ]\n",
|
|
"\n",
|
|
" example_labels = [\n",
|
|
" '🐍 Python: Add Function',\n",
|
|
" '🐍 Python: Multiply Tests',\n",
|
|
" '🐍 Python: Fix Division',\n",
|
|
" '🐍 Python: Binary Search Explained',\n",
|
|
" '🟨 JavaScript: Closure Concept',\n",
|
|
" '🦀 Rust: Ownership Puzzle',\n",
|
|
" '🟨 JavaScript: Async Test',\n",
|
|
" '🐹 Go: Concurrency Test',\n",
|
|
" '⚡ C++: Fix Pointer Bug',\n",
|
|
" '🐘 PHP: Fix SQL Injection',\n",
|
|
" '☕ Java: Bubble Sort Guide',\n",
|
|
" '📘 TypeScript: Generics & Tests',\n",
|
|
" '🦀 Rust: Fix & Explain Ownership',\n",
|
|
" '💎 Ruby: Meta Programming Deep Dive',\n",
|
|
" ]\n",
|
|
"\n",
|
|
" return examples, example_labels"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "wYReYuvgtDgg"
|
|
},
|
|
"source": [
|
|
"## Gradio UI\n",
|
|
"\n",
|
|
"[Documentation](https://www.gradio.app/docs/gradio)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 664
|
|
},
|
|
"id": "I8Q08SJe8CxK",
|
|
"outputId": "f1d41d06-dfda-4daf-b7ff-6f73bdaf8369"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"title = 'Snippet Sniper 🎯'\n",
|
|
"\n",
|
|
"with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:\n",
|
|
" gr.Markdown(f'# {title}')\n",
|
|
" gr.Markdown('## I am your [**John Wick**](https://en.wikipedia.org/wiki/John_Wick), ready to accept any contract on your code. Consider it executed 🎯🔫!.')\n",
|
|
"\n",
|
|
" with gr.Row():\n",
|
|
" with gr.Column():\n",
|
|
" tasks = gr.Dropdown(\n",
|
|
" label=\"Tasks\",\n",
|
|
" choices=[task.value for task in Task],\n",
|
|
" value=Task.COMMENTS,\n",
|
|
" multiselect=True,\n",
|
|
" interactive=True,\n",
|
|
" )\n",
|
|
" code_input = gr.Code(\n",
|
|
" label='Code Input',\n",
|
|
" lines=40,\n",
|
|
" )\n",
|
|
" code_language = gr.Textbox(visible=False)\n",
|
|
"\n",
|
|
" with gr.Column():\n",
|
|
" gr.Markdown('## Kill Zone 🧟🧠💀')\n",
|
|
" code_output = gr.Markdown('💣')\n",
|
|
"\n",
|
|
"\n",
|
|
" run_btn = gr.Button('📜 Issue Contract')\n",
|
|
"\n",
|
|
" def set_language(tasks, code, language):\n",
|
|
" syntax_highlights = ['python', 'c', 'cpp', 'javascript', 'typescript']\n",
|
|
" logger.debug(f'Tasks: {tasks}, Languge: {language}')\n",
|
|
" highlight = language if language in syntax_highlights else None\n",
|
|
"\n",
|
|
" return tasks, gr.Code(value=code, language=highlight)\n",
|
|
"\n",
|
|
" examples, example_labels = get_examples()\n",
|
|
" examples = gr.Examples(\n",
|
|
" examples=examples,\n",
|
|
" example_labels=example_labels,\n",
|
|
" examples_per_page=20,\n",
|
|
" inputs=[tasks, code_input, code_language],\n",
|
|
" outputs=[tasks, code_input],\n",
|
|
" run_on_click=True,\n",
|
|
" fn=set_language\n",
|
|
" )\n",
|
|
"\n",
|
|
" run_btn.click(perform_tasks, inputs=[tasks, code_input], outputs=[code_output])\n",
|
|
"\n",
|
|
"ui.launch(debug=True)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|