diff --git a/community-contributions/abdoul/week_three_exercise.ipynb b/community-contributions/abdoul/week_three_exercise.ipynb
new file mode 100644
index 0000000..6157e68
--- /dev/null
+++ b/community-contributions/abdoul/week_three_exercise.ipynb
@@ -0,0 +1,1767 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "3f7540d9",
+   "metadata": {
+    "id": "3f7540d9"
+   },
+   "outputs": [],
+   "source": [
+    "from google.colab import userdata\n",
+    "from huggingface_hub import login\n",
+    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
+    "import torch\n",
+    "\n",
+    "from functools import lru_cache\n",
+    "from diffusers import StableDiffusionPipeline\n",
+    "import gradio as gr"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "hf_token = userdata.get('HF_TOKEN')\n",
+    "login(hf_token, add_to_git_credential=True)"
+   ],
+   "metadata": {
+    "id": "BX0nP9tyGG6S"
+   },
+   "id": "BX0nP9tyGG6S",
+   "execution_count": 23,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "27c5d024",
+   "metadata": {
+    "id": "27c5d024"
+   },
+   "outputs": [],
+   "source": [
+    "TEXT_MODEL_ID = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
+    "IMAGE_MODEL_ID = \"runwayml/stable-diffusion-v1-5\"\n",
+    "\n",
+    "FORMAT_RULES = {\n",
+    "    \"JSON\": \"Return a JSON array containing exactly the requested number of objects with consistent fields tailored to the context.\",\n",
+    "    \"CSV\": \"Return a CSV document with a header row and the requested number of data rows aligned to the context.\",\n",
+    "    \"Raw Text\": \"Return the requested number of prose entries, separated by blank lines, that reflect the context.\",\n",
+    "    \"Code\": \"Return the requested number of code snippets grouped in a single fenced block that models the context.\"\n",
+    "}\n",
+    "\n",
+    "# Cached per quantization setting, so each variant is loaded at most once.\n",
+    "@lru_cache(maxsize=2)\n",
+    "def load_text_components(use_quant: bool):\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)\n",
+    "    if tokenizer.pad_token is None:\n",
+    "        tokenizer.pad_token = tokenizer.eos_token\n",
+    "    if use_quant:\n",
+    "        quant_config = BitsAndBytesConfig(\n",
+    "            load_in_4bit=True,\n",
+    "            bnb_4bit_use_double_quant=True,\n",
+    "            bnb_4bit_compute_dtype=torch.bfloat16,\n",
+    "            bnb_4bit_quant_type=\"nf4\"\n",
+    "        )\n",
+    "        model = AutoModelForCausalLM.from_pretrained(\n",
+    "            TEXT_MODEL_ID,\n",
+    "            device_map=\"auto\",\n",
+    "            quantization_config=quant_config,\n",
+    "            trust_remote_code=True\n",
+    "        )\n",
+    "    else:\n",
+    "        model = AutoModelForCausalLM.from_pretrained(\n",
+    "            TEXT_MODEL_ID,\n",
+    "            device_map=\"auto\",\n",
+    "            torch_dtype=torch.float16,\n",
+    "            trust_remote_code=True\n",
+    "        )\n",
+    "\n",
+    "    model.eval()\n",
+    "    return tokenizer, model\n",
+    "\n",
+    "def build_text_messages(style: str, context: str, return_format: str, record_count: int):\n",
+    "    context_value = context.strip() if context else \"general purpose scenario\"\n",
+    "    style_value = style.strip() if style else \"Balanced\"\n",
+    "    directive = FORMAT_RULES[return_format]\n",
+    "    system_prompt = (\n",
+    "        \"You generate synthetic datasets that are high quality, diverse, and free of personally identifiable information. \"\n",
+    "        + directive\n",
+    "        + \" Ensure outputs are consistent in structure, imaginative in content, and avoid explanations.\"\n",
+    "    )\n",
+    "    user_prompt = f\"Context: {context_value}\\nStyle: {style_value}\\nRecords: {record_count}\\nOutput style: {return_format}\"\n",
+    "    return [\n",
+    "        {\"role\": \"system\", \"content\": system_prompt},\n",
+    "        {\"role\": \"user\", \"content\": user_prompt}\n",
+    "    ]\n",
+    "\n",
+    "def generate_text_data(style: str, context: str, return_format: str, quantize: bool, record_count: int):\n",
+    "    tokenizer, model = load_text_components(bool(quantize))\n",
+    "    messages = build_text_messages(style, context, return_format, int(record_count))\n",
+    "    inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True)\n",
+    "    inputs = inputs.to(model.device)  # follow the model's device instead of assuming CUDA\n",
+    "    attention_mask = torch.ones_like(inputs)\n",
+    "    with torch.inference_mode():\n",
+    "        generated = model.generate(\n",
+    "            input_ids=inputs,\n",
+    "            attention_mask=attention_mask,\n",
+    "            max_new_tokens=512,\n",
+    "            temperature=0.7,\n",
+    "            top_p=0.9,\n",
+    "            repetition_penalty=1.05,\n",
+    "            do_sample=True\n",
+    "        )\n",
+    "\n",
+    "    output_ids = generated[:, inputs.shape[-1]:]\n",
+    "    text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]\n",
+    "    return text.strip()\n",
+    "\n",
+    "@lru_cache(maxsize=1)\n",
+    "def load_image_pipeline():\n",
+    "    pipeline = StableDiffusionPipeline.from_pretrained(IMAGE_MODEL_ID, torch_dtype=torch.float16)\n",
+    "    pipeline = pipeline.to(\"cuda\")\n",
+    "    return pipeline\n",
+    "\n",
+    "def generate_image_data(style: str, context: str, image_prompt: str, image_count: int):\n",
+    "    pipeline = load_image_pipeline()\n",
+    "    parts = []\n",
+    "    if image_prompt:\n",
+    "        parts.append(image_prompt.strip())\n",
+    "    if context:\n",
+    "        parts.append(context.strip())\n",
+    "\n",
+    "    base = \", \".join([p for p in parts if p])\n",
+    "    if not base:\n",
+    "        base = \"Synthetic data concept visualization\"\n",
+    "\n",
+    "    prompt = f\"{base}, {style.lower()} style\"\n",
+    "    images = pipeline(prompt, num_images_per_prompt=int(image_count), guidance_scale=7.0, num_inference_steps=30).images\n",
+    "    return images\n",
+    "\n",
+    "def run_generation(data_type: str, style: str, context: str, return_format: str, quantize: bool, image_prompt: str, record_count: int, image_count: int):\n",
+    "    if data_type == \"Text\":\n",
+    "        text = generate_text_data(style, context, return_format, quantize, record_count)\n",
+    "        return gr.update(value=text, visible=True), gr.update(value=[], visible=False)\n",
+    "\n",
+    "    images = generate_image_data(style, context, image_prompt, image_count)\n",
+    "    return gr.update(value=\"\", visible=False), gr.update(value=images, visible=True)\n",
+    "\n",
+    "def toggle_inputs(data_type: str):\n",
+    "    # Update order matches the outputs list wired to data_type.change below:\n",
+    "    # return_format, quantize, image_prompt, record_count, image_count, text_output, image_output\n",
+    "    if data_type == \"Text\":\n",
+    "        return (\n",
+    "            gr.update(visible=True),\n",
+    "            gr.update(visible=True),\n",
+    "            gr.update(visible=False),\n",
+    "            gr.update(visible=True),\n",
+    "            gr.update(visible=False),\n",
+    "            gr.update(value=\"\", visible=True),\n",
+    "            gr.update(value=[], visible=False)\n",
+    "        )\n",
+    "    return (\n",
+    "        gr.update(visible=False),\n",
+    "        gr.update(visible=False),\n",
+    "        gr.update(visible=True),\n",
+    "        gr.update(visible=False),\n",
+    "        gr.update(visible=True),\n",
+    "        gr.update(value=\"\", visible=False),\n",
+    "        gr.update(value=[], visible=True)\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "3d1c45e6",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 919
+ "referenced_widgets": [ + "4c8d7ca79cc74a94a98b0713b20dfe8f", + "6b9303db835d4a89b0b8dc61d122cee2", + "ef3165923086493cb77cf0884e32e004", + "141b5765befc4dabb34cb078da81293c", + "70170fe59f9f4c4c8d4c31669b4231d6", + "55574da55e0a4ac295f4a979a91c0eb8", + "8db965e763e246bba8463cbf37ab2266", + "3417a8ea7eba45dcb1a3504b34111619", + "8e7ef80834774a779247865809b88ee1", + "52570f3f9f3242d9891454331f20d321", + "0496daff34c546dfba5f1aab3cada450", + "91840a9230c64ec09a9b7211c547ac30", + "83fa6555742847babffb2a0396f704b5", + "ab9c5275104d4fb48a4cdfe820f85414", + "31228bbd323f4afa9b9cf9c300d45554", + "2a036f8bb20547d3a5b5b770fada56ae", + "4525e8fefbcc42c892906f7a035dafdf", + "fc65dfc49f074695a4a30a38c302ac7e", + "6855bec8f66143fd8c2554e7b6b0020e", + "1469789da1884215a3e97a1e56d5da5d", + "06d95dacb31e4e41a2b68340b9c77fc6", + "04480cc4ac134b76b7b726706db73bef", + "192a0110e8064f56ae288e0781fa4583", + "59ffe392a11b40aab9138d5b64d32297", + "701777b6e2cb49699bd03b22710b5b6f", + "295ad85256ba4e3db613067c5d1c21c3", + "b25546e1ad2b4ab29d33646800eb1270", + "aa072bf87a064282ad50649cf399ac29", + "13fb3056e0c64646aa89731e48c24659", + "1380b97fd02d4b33b08b93cb52e73325", + "38d3940d694241379cafb30a9f290a2d", + "e4241f1533364f808c46f9b5ef4c8c23", + "38dd79b5e8f04c7c86e6b7825719498a", + "a895f63ec6e0486eaf253775142e0778", + "da79639a14d14766980360ba6e20b7ed", + "519cd112cb164744aa0c4a5ac43eedf4", + "a4c22f140aeb44d39b05cbcda1b7d357", + "b08fa2ac551a4fff9fd909f5b49d4ceb", + "97c4d462448049a4a464d8f979672f97", + "9ab6bb6a789d4ccbbd9eee468056ba9b", + "26b3c420785d436e8ac334df2f97d28a", + "8267fba3eac04f43aa22ec3f69731841", + "4e417308ca3a41b4b1a8332cdbc4098f", + "a2066a8649584f8fab62e91c3d07e25e" + ] + }, + "id": "3d1c45e6", + "outputId": "483c396f-8876-4d59-db22-debe4b2bb2b8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n", + "\n", + "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n", + "* Running on public URL: https://b5fd391afd63f4968c.gradio.live\n", + "\n", + "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 https://b5fd391afd63f4968c.gradio.live\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [] + }, + "metadata": {}, + "execution_count": 25 + } + ], + "source": [ + "with gr.Blocks(title=\"Synthetic Data Generator\") as demo:\n", + " gr.Markdown(\"## Synthetic Data Generator\")\n", + " with gr.Row():\n", + " data_type = gr.Radio([\"Text\", \"Image\"], label=\"Type\", value=\"Text\")\n", + " style = gr.Dropdown([\"Concise\", \"Detailed\", \"Narrative\", \"Technical\", \"Tabular\"], label=\"Style\", value=\"Detailed\")\n", + "\n", + " context_input = gr.Textbox(label=\"Context\", lines=4, placeholder=\"Describe the entities, attributes, and purpose of the dataset.\")\n", + " return_format = gr.Dropdown([\"JSON\", \"CSV\", \"Raw Text\", \"Code\"], label=\"Return Format\", value=\"JSON\")\n", + " quantize = gr.Checkbox(label=\"Quantize\", value=False)\n", + " record_count = gr.Slider(1, 20, value=5, step=1, label=\"Records\")\n", + " image_prompt = gr.Textbox(label=\"Image Prompt\", lines=2, visible=False, placeholder=\"Detail the visual you want to synthesize.\")\n", + " image_count = gr.Slider(1, 4, value=1, step=1, label=\"Images\", visible=False)\n", + " generate_button = gr.Button(\"Generate\")\n", + " text_output = gr.Textbox(label=\"Text Output\", lines=12)\n", + " image_output = gr.Gallery(label=\"Generated Images\", visible=False, columns=2, rows=1)\n", + " data_type.change(\n", + " toggle_inputs,\n", + " inputs=data_type,\n", + " outputs=[return_format, quantize, image_prompt, record_count, image_count, text_output, image_output]\n", + " )\n", + " generate_button.click(\n", + " run_generation,\n", + " inputs=[data_type, style, context_input, return_format, quantize, image_prompt, record_count, image_count],\n", + " outputs=[text_output, image_output]\n", + " )\n", + "\n", + "\n", + "demo.launch(debug=True)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "history_visible": true, + "gpuType": "T4" + }, + "accelerator": "GPU", + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "4c8d7ca79cc74a94a98b0713b20dfe8f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6b9303db835d4a89b0b8dc61d122cee2", + "IPY_MODEL_ef3165923086493cb77cf0884e32e004", + "IPY_MODEL_141b5765befc4dabb34cb078da81293c" + ], + "layout": "IPY_MODEL_70170fe59f9f4c4c8d4c31669b4231d6" + } + }, + "6b9303db835d4a89b0b8dc61d122cee2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_55574da55e0a4ac295f4a979a91c0eb8", + "placeholder": "​", + "style": "IPY_MODEL_8db965e763e246bba8463cbf37ab2266", + "value": "Loading checkpoint shards: 100%" + } + }, + "ef3165923086493cb77cf0884e32e004": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3417a8ea7eba45dcb1a3504b34111619", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8e7ef80834774a779247865809b88ee1", + "value": 2 + } + }, + "141b5765befc4dabb34cb078da81293c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_52570f3f9f3242d9891454331f20d321", + "placeholder": "​", + "style": "IPY_MODEL_0496daff34c546dfba5f1aab3cada450", + "value": " 2/2 [00:25<00:00, 11.32s/it]" + } + }, + "70170fe59f9f4c4c8d4c31669b4231d6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "55574da55e0a4ac295f4a979a91c0eb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, 
+ "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8db965e763e246bba8463cbf37ab2266": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3417a8ea7eba45dcb1a3504b34111619": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8e7ef80834774a779247865809b88ee1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "52570f3f9f3242d9891454331f20d321": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0496daff34c546dfba5f1aab3cada450": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "91840a9230c64ec09a9b7211c547ac30": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_83fa6555742847babffb2a0396f704b5", + "IPY_MODEL_ab9c5275104d4fb48a4cdfe820f85414", + "IPY_MODEL_31228bbd323f4afa9b9cf9c300d45554" + ], + "layout": "IPY_MODEL_2a036f8bb20547d3a5b5b770fada56ae" + } + }, + "83fa6555742847babffb2a0396f704b5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4525e8fefbcc42c892906f7a035dafdf", + "placeholder": "​", + "style": "IPY_MODEL_fc65dfc49f074695a4a30a38c302ac7e", + "value": "Loading pipeline components...: 100%" + } + }, + "ab9c5275104d4fb48a4cdfe820f85414": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6855bec8f66143fd8c2554e7b6b0020e", + "max": 7, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1469789da1884215a3e97a1e56d5da5d", + "value": 7 + } + }, + "31228bbd323f4afa9b9cf9c300d45554": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_06d95dacb31e4e41a2b68340b9c77fc6", + "placeholder": "​", + "style": 
"IPY_MODEL_04480cc4ac134b76b7b726706db73bef", + "value": " 7/7 [00:28<00:00,  5.58s/it]" + } + }, + "2a036f8bb20547d3a5b5b770fada56ae": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4525e8fefbcc42c892906f7a035dafdf": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fc65dfc49f074695a4a30a38c302ac7e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6855bec8f66143fd8c2554e7b6b0020e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + 
"bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1469789da1884215a3e97a1e56d5da5d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "06d95dacb31e4e41a2b68340b9c77fc6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "04480cc4ac134b76b7b726706db73bef": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "192a0110e8064f56ae288e0781fa4583": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_59ffe392a11b40aab9138d5b64d32297", + "IPY_MODEL_701777b6e2cb49699bd03b22710b5b6f", + "IPY_MODEL_295ad85256ba4e3db613067c5d1c21c3" + 
], + "layout": "IPY_MODEL_b25546e1ad2b4ab29d33646800eb1270" + } + }, + "59ffe392a11b40aab9138d5b64d32297": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_aa072bf87a064282ad50649cf399ac29", + "placeholder": "​", + "style": "IPY_MODEL_13fb3056e0c64646aa89731e48c24659", + "value": "100%" + } + }, + "701777b6e2cb49699bd03b22710b5b6f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1380b97fd02d4b33b08b93cb52e73325", + "max": 30, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_38d3940d694241379cafb30a9f290a2d", + "value": 30 + } + }, + "295ad85256ba4e3db613067c5d1c21c3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4241f1533364f808c46f9b5ef4c8c23", + "placeholder": "​", + "style": "IPY_MODEL_38dd79b5e8f04c7c86e6b7825719498a", + "value": " 30/30 [00:05<00:00,  6.72it/s]" + } + }, + "b25546e1ad2b4ab29d33646800eb1270": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aa072bf87a064282ad50649cf399ac29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": 
"@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "13fb3056e0c64646aa89731e48c24659": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1380b97fd02d4b33b08b93cb52e73325": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38d3940d694241379cafb30a9f290a2d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e4241f1533364f808c46f9b5ef4c8c23": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38dd79b5e8f04c7c86e6b7825719498a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a895f63ec6e0486eaf253775142e0778": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_da79639a14d14766980360ba6e20b7ed", + "IPY_MODEL_519cd112cb164744aa0c4a5ac43eedf4", + "IPY_MODEL_a4c22f140aeb44d39b05cbcda1b7d357" + ], + "layout": "IPY_MODEL_b08fa2ac551a4fff9fd909f5b49d4ceb" + } + }, + "da79639a14d14766980360ba6e20b7ed": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_97c4d462448049a4a464d8f979672f97", + "placeholder": "​", + "style": "IPY_MODEL_9ab6bb6a789d4ccbbd9eee468056ba9b", + "value": "100%" + } + }, + "519cd112cb164744aa0c4a5ac43eedf4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_26b3c420785d436e8ac334df2f97d28a", + "max": 30, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8267fba3eac04f43aa22ec3f69731841", + "value": 30 + } + }, + "a4c22f140aeb44d39b05cbcda1b7d357": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4e417308ca3a41b4b1a8332cdbc4098f", + "placeholder": "​", + "style": "IPY_MODEL_a2066a8649584f8fab62e91c3d07e25e", + "value": " 30/30 [00:04<00:00,  6.89it/s]" + } + }, + "b08fa2ac551a4fff9fd909f5b49d4ceb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "97c4d462448049a4a464d8f979672f97": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9ab6bb6a789d4ccbbd9eee468056ba9b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + 
"description_width": "" + } + }, + "26b3c420785d436e8ac334df2f97d28a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8267fba3eac04f43aa22ec3f69731841": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4e417308ca3a41b4b1a8332cdbc4098f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a2066a8649584f8fab62e91c3d07e25e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file 
diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/ReadME.md b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/ReadME.md deleted file mode 100644 index 0d1952a..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/ReadME.md +++ /dev/null @@ -1,50 +0,0 @@ -# **Automated Bitcoin Daily Summary Generator** - -This project automates the process of generating a daily summary of the Bitcoin network's status. It fetches real-time data from multiple public API endpoints, processes it, and then uses a Large Language Model (LLM) to generate a clear, structured, and human-readable report in Markdown format. - -## **Project Overview** - -The core goal of this project is to provide a snapshot of key Bitcoin metrics without manual analysis. By leveraging the Braiins Public API for data and OpenAI's GPT models for summarization, it can produce insightful daily reports covering market trends, network health, miner revenue, and future outlooks like the next halving event. - -### **Key Features** - -- **Automated Data Fetching**: Pulls data from 7 different Braiins API endpoints covering price, hashrate, difficulty, transaction fees, and more. -- **Data Cleaning**: Pre-processes the raw JSON data to make it clean and suitable for the LLM. -- **Intelligent Summarization**: Uses an advanced LLM to analyze the data and generate a structured report with explanations for technical terms. -- **Dynamic Dating**: The report is always dated for the day it is run, providing a timely summary regardless of the timestamps in the source data. -- **Markdown Output**: Generates a clean, well-formatted Markdown file that is easy to read or integrate into other systems. - -## **How It Works** - -The project is split into two main files: - -1. **utils.py**: A utility script responsible for all data fetching and cleaning operations. - - It defines the Braiins API endpoints to be queried. - - It contains functions to handle HTTP requests, parse JSON responses, and clean up keys and values to ensure consistency. -2. **day_1_bitcoin_daily_brief.ipynb**: A Jupyter Notebook that acts as the main orchestrator. - - It imports the necessary functions from utils.py. - - It calls fetch_clean_data() to get the latest Bitcoin network data. - - It constructs a detailed system and user prompt for the LLM, explicitly instructing it on the desired format and, crucially, to use the current date for the summary. - - It sends the data and prompt to the OpenAI API. - - It receives the generated summary and displays it as formatted Markdown. - -## **Setup and Usage** - -To run this project, you will need to have Python and the required libraries installed. - -### **1\. Prerequisites** - -- Python 3.x -- Jupyter Notebook or JupyterLab - -### **2\. Installation** - -- Install the necessary Python libraries: pip install requests openai python-dotenv jupyter - -### **3\. Configuration** - -You need an API key from OpenAI to use the summarization feature. - -1. Create a file named .env in the root directory of the project. -2. 
Add your OpenAI API key to the .env file as follows: - OPENAI_API_KEY='your_openai_api_key_here' diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/day_1_bitcoin_daily_brief.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/day_1_bitcoin_daily_brief.ipynb deleted file mode 100644 index b99d8b5..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/day_1_bitcoin_daily_brief.ipynb +++ /dev/null @@ -1,156 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "abaef96b", - "metadata": {}, - "source": [ - "## Importing The Libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "f90c541b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import datetime\n", - "from utils import fetch_clean_data\n", - "from openai import OpenAI\n", - "from IPython.display import Markdown, display\n", - "from dotenv import load_dotenv\n", - "import json" - ] - }, - { - "cell_type": "markdown", - "id": "6e6c864b", - "metadata": {}, - "source": [ - "## Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "be62299d", - "metadata": {}, - "outputs": [], - "source": [ - "load_dotenv(override=True)\n", - "api_key = os.getenv('OPENAI_API_KEY')\n", - "\n", - "client = OpenAI()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "3aa8e3e2", - "metadata": {}, - "outputs": [], - "source": [ - "def generate_markdown_summary(data: dict, today_date_str: str) -> str:\n", - " \"\"\"\n", - " Send cleaned Bitcoin data to an LLM and receive a Markdown summary.\n", - " \"\"\"\n", - "\n", - " system_prompt = f\"\"\"\n", - " You are a professional crypto analyst. Your job is to read the provided Bitcoin network data \n", - " and write a clear, structured report that can be read directly as a daily summary.\n", - "\n", - " Following are the rules that you must adhere to:\n", - " - **IMPORTANT**: The summary title MUST use today's date: {today_date_str}. The title must be: \"Bitcoin Daily Summary - {today_date_str}\".\n", - " - **CRITICAL**: Do NOT infer the reporting period from the data. The data contains historical records, but your report is for {today_date_str}.\n", - " - Include **headings** for sections like \"Market Overview\", \"Network Metrics Explained\", \"Miner Revenue Trends\", and \"Halving Outlook\".\n", - " - Use **bullet points** for key metrics.\n", - " - Use a **table** for historical or time-series data if available.\n", - " - Explain important terms (like hashrate, difficulty, transaction fees) in plain language.\n", - "\n", - " Respond in markdown. Do not wrap the markdown in a code block - respond just with the markdown.\n", - " \"\"\"\n", - "\n", - " # Convert the Python data dictionary into a clean JSON string for the prompt\n", - " data_str = json.dumps(data, indent=2)\n", - "\n", - " user_prompt = f\"\"\"\n", - " Today's date is {today_date_str}. 
Use this as the reference point for the report.\n", - "\n", - " The following data may contain historical records (e.g., from 2024), \n", - " but you must treat it as background context and write the summary as of {today_date_str}.\n", - "\n", - " Here is the data for you to summarize: \n", - " {data_str}\n", - " \"\"\"\n", - " \n", - " response = client.chat.completions.create(\n", - " model= \"gpt-4.1-mini\", \n", - " messages=[\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": user_prompt}\n", - " ]\n", - " )\n", - "\n", - " markdown_text = response.choices[0].message.content.strip()\n", - " return markdown_text" - ] - }, - { - "cell_type": "markdown", - "id": "1e8c2d7d", - "metadata": {}, - "source": [ - "## Main Function" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05059ed9", - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " # 0. Get today's date as a string\n", - " today_str = datetime.datetime.now().strftime('%B %d, %Y')\n", - " \n", - " # 1. Fetch and clean data\n", - " print(\"Fetching Bitcoin data...\")\n", - " data = fetch_clean_data()\n", - "\n", - " # 2. Generate Markdown summary\n", - " print(\"Generating LLM summary...\")\n", - " markdown_report = generate_markdown_summary(data, today_str)\n", - "\n", - " # 3. Display Output\n", - " display(Markdown(markdown_report))\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llm-engineering", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/utils.py b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/utils.py deleted file mode 100644 index 7371374..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 1/utils.py +++ /dev/null @@ -1,121 +0,0 @@ -# utils.py - -import requests -import re -import datetime -import logging -from typing import Dict, Optional, Union - -# ----------------------------------------- -# Logging setup -# ----------------------------------------- -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# ----------------------------------------- -# Braiins API endpoints (7 selected) -# ----------------------------------------- -BRAIINS_APIS = { - 'price_stats': 'https://insights.braiins.com/api/v1.0/price-stats', - 'hashrate_stats': 'https://insights.braiins.com/api/v1.0/hashrate-stats', - 'difficulty_stats': 'https://insights.braiins.com/api/v1.0/difficulty-stats', - 'transaction_fees_history': 'https://insights.braiins.com/api/v1.0/transaction-fees-history', - 'daily_revenue_history': 'https://insights.braiins.com/api/v1.0/daily-revenue-history', - 'hashrate_value_history': 'https://insights.braiins.com/api/v1.0/hashrate-value-history', - 'halvings': 'https://insights.braiins.com/api/v2.0/halvings' -} - - -# ----------------------------------------- -# Utility Functions -# ----------------------------------------- -def clean_value(value): - """Clean strings, remove brackets/quotes and standardize whitespace.""" - if value is None: - return "" - s = str(value) - s = s.replace(",", " ") - s = re.sub(r"[\[\]\{\}\(\)]", "", 
s) - s = s.replace('"', "").replace("'", "") - s = re.sub(r"\s+", " ", s) - return s.strip() - - -def parse_date(date_str: str) -> Optional[str]: - """Parse dates into a standard readable format.""" - if not date_str or not isinstance(date_str, str): - return None - try: - if 'T' in date_str: - return datetime.datetime.fromisoformat(date_str.replace('Z', '').split('.')[0]).strftime('%Y-%m-%d %H:%M:%S') - if '-' in date_str and len(date_str) == 10: - return datetime.datetime.strptime(date_str, '%Y-%m-%d').strftime('%Y-%m-%d %H:%M:%S') - if '/' in date_str and len(date_str) == 10: - return datetime.datetime.strptime(date_str, '%m/%d/%Y').strftime('%Y-%m-%d %H:%M:%S') - except Exception: - return date_str - return date_str - - -def fetch_endpoint_data(url: str) -> Optional[Union[Dict, list]]: - """Generic GET request to Braiins API endpoint.""" - try: - resp = requests.get(url, timeout=15) - resp.raise_for_status() - return resp.json() - except Exception as e: - logger.error(f"Failed to fetch {url}: {e}") - return None - - -def clean_and_process_data(data: Union[Dict, list]) -> Union[Dict, list]: - """Clean all keys and values in the fetched data.""" - if isinstance(data, dict): - return {clean_value(k): clean_value(v) for k, v in data.items()} - elif isinstance(data, list): - cleaned_list = [] - for item in data: - if isinstance(item, dict): - cleaned_list.append({clean_value(k): clean_value(v) for k, v in item.items()}) - else: - cleaned_list.append(clean_value(item)) - return cleaned_list - return clean_value(data) - - -# ----------------------------------------- -# Main data fetcher -# ----------------------------------------- -def fetch_clean_data(history_limit: int = 30) -> Dict[str, Union[Dict, list]]: - """ - Fetch and clean data from 7 selected Braiins endpoints. - For historical data, it limits the number of records. - Returns a dictionary ready to be passed into an LLM. - """ - logger.info("Fetching Bitcoin network data from Braiins...") - results = {} - - for key, url in BRAIINS_APIS.items(): - logger.info(f"Fetching {key} ...") - raw_data = fetch_endpoint_data(url) - if raw_data is not None: - # --- START OF THE NEW CODE --- - # If the endpoint is for historical data, limit the number of records - if "history" in key and isinstance(raw_data, list): - logger.info(f"Limiting {key} data to the last {history_limit} records.") - raw_data = raw_data[-history_limit:] - # --- END OF THE NEW CODE --- - - results[key] = clean_and_process_data(raw_data) - else: - results[key] = {"error": "Failed to fetch"} - - logger.info("All data fetched and cleaned successfully.") - return results - -# ----------------------------------------- -# Local test run (optional) -# ----------------------------------------- -if __name__ == "__main__": - data = fetch_clean_data() - print("Sample keys fetched:", list(data.keys())) diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/ReadME.md b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/ReadME.md deleted file mode 100644 index 0d1952a..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/ReadME.md +++ /dev/null @@ -1,50 +0,0 @@ -# **Automated Bitcoin Daily Summary Generator** - -This project automates the process of generating a daily summary of the Bitcoin network's status. It fetches real-time data from multiple public API endpoints, processes it, and then uses a Large Language Model (LLM) to generate a clear, structured, and human-readable report in Markdown format. 
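To make the flow concrete, the entire pipeline described above fits in a few lines. This is a minimal, illustrative sketch, not the project's exact notebook code: it assumes this project's `utils.py` is importable, that an `OPENAI_API_KEY` is configured, and the model name is a placeholder.

```python
# Sketch of the fetch -> summarize -> Markdown pipeline (fetch_clean_data comes
# from this project's utils.py; the model name is a placeholder, not necessarily
# the one the notebook uses).
import datetime
import json

from openai import OpenAI
from utils import fetch_clean_data

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
data = fetch_clean_data()  # cleaned data from the 7 Braiins endpoints
today = datetime.datetime.now().strftime("%B %d, %Y")  # report is dated for today

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[
        {"role": "system", "content": f"Write a Markdown report titled 'Bitcoin Daily Summary - {today}'."},
        {"role": "user", "content": json.dumps(data, indent=2)},
    ],
)
print(response.choices[0].message.content)
```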
- -## **Project Overview** - -The core goal of this project is to provide a snapshot of key Bitcoin metrics without manual analysis. By leveraging the Braiins Public API for data and OpenAI's GPT models for summarization, it can produce insightful daily reports covering market trends, network health, miner revenue, and future outlooks like the next halving event. - -### **Key Features** - -- **Automated Data Fetching**: Pulls data from 7 different Braiins API endpoints covering price, hashrate, difficulty, transaction fees, and more. -- **Data Cleaning**: Pre-processes the raw JSON data to make it clean and suitable for the LLM. -- **Intelligent Summarization**: Uses an advanced LLM to analyze the data and generate a structured report with explanations for technical terms. -- **Dynamic Dating**: The report is always dated for the day it is run, providing a timely summary regardless of the timestamps in the source data. -- **Markdown Output**: Generates a clean, well-formatted Markdown file that is easy to read or integrate into other systems. - -## **How It Works** - -The project is split into two main files: - -1. **utils.py**: A utility script responsible for all data fetching and cleaning operations. - - It defines the Braiins API endpoints to be queried. - - It contains functions to handle HTTP requests, parse JSON responses, and clean up keys and values to ensure consistency. -2. **day_1_bitcoin_daily_brief.ipynb**: A Jupyter Notebook that acts as the main orchestrator. - - It imports the necessary functions from utils.py. - - It calls fetch_clean_data() to get the latest Bitcoin network data. - - It constructs a detailed system and user prompt for the LLM, explicitly instructing it on the desired format and, crucially, to use the current date for the summary. - - It sends the data and prompt to the OpenAI API. - - It receives the generated summary and displays it as formatted Markdown. - -## **Setup and Usage** - -To run this project, you will need to have Python and the required libraries installed. - -### **1\. Prerequisites** - -- Python 3.x -- Jupyter Notebook or JupyterLab - -### **2\. Installation** - -- Install the necessary Python libraries: pip install requests openai python-dotenv jupyter - -### **3\. Configuration** - -You need an API key from OpenAI to use the summarization feature. - -1. Create a file named .env in the root directory of the project. -2. 
Add your OpenAI API key to the .env file as follows: - OPENAI_API_KEY='your_openai_api_key_here' diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/day_2_bitcoin_daily_brief_with_ollama.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/day_2_bitcoin_daily_brief_with_ollama.ipynb deleted file mode 100644 index 548f963..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/day_2_bitcoin_daily_brief_with_ollama.ipynb +++ /dev/null @@ -1,152 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "abaef96b", - "metadata": {}, - "source": [ - "## Importing The Libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f90c541b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import datetime\n", - "from utils import fetch_clean_data\n", - "from openai import OpenAI\n", - "from IPython.display import Markdown, display\n", - "import json" - ] - }, - { - "cell_type": "markdown", - "id": "6e6c864b", - "metadata": {}, - "source": [ - "## Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "be62299d", - "metadata": {}, - "outputs": [], - "source": [ - "client = OpenAI(base_url='http://localhost:11434/v1', api_key = 'ollama')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3aa8e3e2", - "metadata": {}, - "outputs": [], - "source": [ - "def generate_markdown_summary(data: dict, today_date_str: str) -> str:\n", - " \"\"\"\n", - " Send cleaned Bitcoin data to an LLM and receive a Markdown summary.\n", - " \"\"\"\n", - "\n", - " system_prompt = f\"\"\"\n", - " You are a professional crypto analyst. Your job is to read the provided Bitcoin network data \n", - " and write a clear, structured report that can be read directly as a daily summary.\n", - "\n", - " Following are the rules that you must adhere to:\n", - " - **IMPORTANT**: The summary title MUST use today's date: {today_date_str}. The title must be: \"Bitcoin Daily Summary - {today_date_str}\".\n", - " - **CRITICAL**: Do NOT infer the reporting period from the data. The data contains historical records, but your report is for {today_date_str}.\n", - " - Include **headings** for sections like \"Market Overview\", \"Network Metrics Explained\", \"Miner Revenue Trends\", and \"Halving Outlook\".\n", - " - Use **bullet points** for key metrics.\n", - " - Use a **table** for historical or time-series data if available.\n", - " - Explain important terms (like hashrate, difficulty, transaction fees) in plain language.\n", - "\n", - " Respond in markdown. Do not wrap the markdown in a code block - respond just with the markdown.\n", - " \"\"\"\n", - "\n", - " # Convert the Python data dictionary into a clean JSON string for the prompt\n", - " data_str = json.dumps(data, indent=2)\n", - "\n", - " user_prompt = f\"\"\"\n", - " Today's date is {today_date_str}. 
Use this as the reference point for the report.\n", - "\n", - " The following data may contain historical records (e.g., from 2024), \n", - " but you must treat it as background context and write the summary as of {today_date_str}.\n", - "\n", - " Here is the data for you to summarize: \n", - " {data_str}\n", - " \"\"\"\n", - " \n", - " response = client.chat.completions.create(\n", - " model= \"llama3.2\", \n", - " messages=[\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": user_prompt}\n", - " ]\n", - " )\n", - "\n", - " markdown_text = response.choices[0].message.content.strip()\n", - " return markdown_text" - ] - }, - { - "cell_type": "markdown", - "id": "1e8c2d7d", - "metadata": {}, - "source": [ - "## Main Function" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05059ed9", - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " # 0. Get today's date as a string\n", - " today_str = datetime.datetime.now().strftime('%B %d, %Y')\n", - " \n", - " # 1. Fetch and clean data\n", - " print(\"Fetching Bitcoin data...\")\n", - " data = fetch_clean_data()\n", - "\n", - " # 2. Generate Markdown summary\n", - " print(\"Generating LLM summary...\")\n", - " markdown_report = generate_markdown_summary(data, today_str)\n", - "\n", - " # 3. Display Output\n", - " display(Markdown(markdown_report))\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llm-engineering", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/utils.py b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/utils.py deleted file mode 100644 index ad16069..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 2/utils.py +++ /dev/null @@ -1,113 +0,0 @@ -# utils.py - -import requests -import re -import datetime -import logging -from typing import Dict, Optional, Union - -# ----------------------------------------- -# Logging setup -# ----------------------------------------- -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# ----------------------------------------- -# Braiins API endpoints (7 selected) -# ----------------------------------------- -BRAIINS_APIS = { - 'price_stats': 'https://insights.braiins.com/api/v1.0/price-stats', - 'hashrate_stats': 'https://insights.braiins.com/api/v1.0/hashrate-stats', - 'difficulty_stats': 'https://insights.braiins.com/api/v1.0/difficulty-stats', - 'transaction_fees_history': 'https://insights.braiins.com/api/v1.0/transaction-fees-history', - 'daily_revenue_history': 'https://insights.braiins.com/api/v1.0/daily-revenue-history', - 'hashrate_value_history': 'https://insights.braiins.com/api/v1.0/hashrate-value-history', - 'halvings': 'https://insights.braiins.com/api/v2.0/halvings' -} - - -# ----------------------------------------- -# Utility Functions -# ----------------------------------------- -def clean_value(value): - """Clean strings, remove brackets/quotes and standardize whitespace.""" - if value is None: - return "" - s = str(value) - s = s.replace(",", " ") - s = re.sub(r"[\[\]\{\}\(\)]", "", s) - 
s = s.replace('"', "").replace("'", "") - s = re.sub(r"\s+", " ", s) - return s.strip() - - -def parse_date(date_str: str) -> Optional[str]: - """Parse dates into a standard readable format.""" - if not date_str or not isinstance(date_str, str): - return None - try: - if 'T' in date_str: - return datetime.datetime.fromisoformat(date_str.replace('Z', '').split('.')[0]).strftime('%Y-%m-%d %H:%M:%S') - if '-' in date_str and len(date_str) == 10: - return datetime.datetime.strptime(date_str, '%Y-%m-%d').strftime('%Y-%m-%d %H:%M:%S') - if '/' in date_str and len(date_str) == 10: - return datetime.datetime.strptime(date_str, '%m/%d/%Y').strftime('%Y-%m-%d %H:%M:%S') - except Exception: - return date_str - return date_str - - -def fetch_endpoint_data(url: str) -> Optional[Union[Dict, list]]: - """Generic GET request to Braiins API endpoint.""" - try: - resp = requests.get(url, timeout=15) - resp.raise_for_status() - return resp.json() - except Exception as e: - logger.error(f"Failed to fetch {url}: {e}") - return None - - -def clean_and_process_data(data: Union[Dict, list]) -> Union[Dict, list]: - """Clean all keys and values in the fetched data.""" - if isinstance(data, dict): - return {clean_value(k): clean_value(v) for k, v in data.items()} - elif isinstance(data, list): - cleaned_list = [] - for item in data: - if isinstance(item, dict): - cleaned_list.append({clean_value(k): clean_value(v) for k, v in item.items()}) - else: - cleaned_list.append(clean_value(item)) - return cleaned_list - return clean_value(data) - - -# ----------------------------------------- -# Main data fetcher -# ----------------------------------------- -def fetch_clean_data() -> Dict[str, Union[Dict, list]]: - """ - Fetch and clean data from 7 selected Braiins endpoints. - Returns a dictionary ready to be passed into an LLM. 
- """ - logger.info("Fetching Bitcoin network data from Braiins...") - results = {} - - for key, url in BRAIINS_APIS.items(): - logger.info(f"Fetching {key} ...") - raw_data = fetch_endpoint_data(url) - if raw_data is not None: - results[key] = clean_and_process_data(raw_data) - else: - results[key] = {"error": "Failed to fetch"} - - logger.info("All data fetched and cleaned successfully.") - return results - -# ----------------------------------------- -# Local test run (optional) -# ----------------------------------------- -if __name__ == "__main__": - data = fetch_clean_data() - print("Sample keys fetched:", list(data.keys())) diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/brochure.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/brochure.ipynb new file mode 100644 index 0000000..9dfa180 --- /dev/null +++ b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/brochure.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "id": "57499cf2", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from IPython.display import Markdown, display, update_display\n", + "from scraper import fetch_website_links, fetch_website_contents\n", + "from openai import OpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "310a13f3", + "metadata": {}, + "outputs": [], + "source": [ + "load_dotenv(override=True)\n", + "api_key = os.getenv('OPENAI_API_KEY')\n", + "\n", + "client = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "79226a7f", + "metadata": {}, + "outputs": [], + "source": [ + "link_analyzer_prompt = \"\"\"\n", + "You are a skilled research analyst. Your task is to identify the most useful introductory links for a given topic from a list of URLs. \n", + "You must ignore forum posts, product pages, and social media links. Focus on high-quality articles, documentation, and educational resources.\n", + "Respond ONLY with a JSON object in the following format:\n", + "{\n", + " \"links\": [\n", + " {\"type\": \"overview_article\", \"url\": \"https://...\"},\n", + " {\"type\": \"technical_docs\", \"url\": \"https://...\"},\n", + " {\"type\": \"history_summary\", \"url\": \"https://...\"}\n", + " ]\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "73d02b52", + "metadata": {}, + "outputs": [], + "source": [ + "briefing_prompt = \"\"\"\n", + "You are an expert intelligence analyst. You will be given raw text from several articles about a topic. \n", + "Your mission is to synthesize this information into a clear and structured research brief. \n", + "The brief must contain the following sections in Markdown:\n", + "\n", + "Research Brief: {topic}\n", + "\n", + "1. Executive Summary\n", + "(A one-paragraph overview of the entire topic.)\n", + "\n", + "2. Key Concepts\n", + "(Use bullet points to list and explain the most important terms and ideas.)\n", + "\n", + "3. Important Figures / Events\n", + "(List the key people, organizations, or historical events relevant to the topic.)\n", + "\n", + "4. 
Further Reading\n", + "(Provide a list of the original URLs you analyzed for deeper study.)\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ab04efb6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_relevant_links(topic: str, starting_url: str) -> dict:\n", + " \n", + " # getting all links from the starting URL\n", + " links_on_page = fetch_website_links(starting_url)\n", + " \n", + " # user prompt for the Link Analyst\n", + " user_prompt = f\"\"\"\n", + " Please analyze the following links related to the topic \"{topic}\" and return the most relevant ones for a research brief.\n", + " The main URL is {starting_url}. Make sure all returned URLs are absolute.\n", + "\n", + " Links:\n", + " {\"\\n\".join(links_on_page)}\n", + " \"\"\"\n", + " \n", + " response = client.chat.completions.create(\n", + " model=\"gpt-4o-mini\", \n", + " messages=[\n", + " {\"role\": \"system\", \"content\": link_analyzer_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " response_format={\"type\": \"json_object\"}\n", + " )\n", + " \n", + " result_json = response.choices[0].message.content\n", + " relevant_links = json.loads(result_json)\n", + " return relevant_links" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ef6ef363", + "metadata": {}, + "outputs": [], + "source": [ + "def get_all_content(links_data: dict) -> str:\n", + " all_content = \"\"\n", + " original_urls = []\n", + "\n", + " for link in links_data.get(\"links\", []):\n", + " url = link.get(\"url\")\n", + " if url:\n", + " original_urls.append(url)\n", + " content = fetch_website_contents(url)\n", + " all_content += f\"Content from {url} \\n{content}\\n\\n\"\n", + " \n", + " all_content += f\"Original URLs for Reference\\n\" + \"\\n\".join(original_urls)\n", + " return all_content" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c2020492", + "metadata": {}, + "outputs": [], + "source": [ + "def create_research_brief(topic: str, starting_url: str):\n", + " relevant_links = get_relevant_links(topic, starting_url)\n", + " full_content = get_all_content(relevant_links)\n", + "\n", + " user_prompt = f\"\"\"\n", + " Please create a research brief on the topic \"{topic}\" using the following content.\n", + " Remember to include the original URLs in the 'Further Reading' section.\n", + "\n", + " Content:\n", + " {full_content[:15000]}\n", + " \"\"\"\n", + " \n", + " stream = client.chat.completions.create(\n", + " model=\"gpt-4o-mini\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": briefing_prompt.format(topic=topic)},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " stream=True\n", + " )\n", + " \n", + " response = \"\"\n", + " display_handle = display(Markdown(\"\"), display_id=True)\n", + " for chunk in stream:\n", + " response += chunk.choices[0].delta.content or ''\n", + " update_display(Markdown(response), display_id=display_handle.display_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "594e940c", + "metadata": {}, + "outputs": [], + "source": [ + "create_research_brief(\n", + " topic=\"The Rise of Artificial Intelligence\", \n", + " starting_url=\"https://en.wikipedia.org/wiki/Artificial_intelligence\"\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/scraper.py b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/scraper.py new file mode 100644 index 0000000..1ecc209 --- /dev/null +++ b/community-contributions/muhammad_qasim_sheikh/Week 1/Day 5/scraper.py @@ -0,0 +1,37 @@ +from bs4 import BeautifulSoup +import requests + + +# Standard headers to fetch a website +headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" +} + + +def fetch_website_contents(url): + """ + Return the title and contents of the website at the given url; + truncate to 2,000 characters as a sensible limit + """ + response = requests.get(url, headers=headers) + soup = BeautifulSoup(response.content, "html.parser") + title = soup.title.string if soup.title else "No title found" + if soup.body: + for irrelevant in soup.body(["script", "style", "img", "input"]): + irrelevant.decompose() + text = soup.body.get_text(separator="\n", strip=True) + else: + text = "" + return (title + "\n\n" + text)[:2_000] + + +def fetch_website_links(url): + """ + Return the links on the website at the given url. + I realize this is inefficient as we're parsing twice! This is to keep the code in the lab simple. + Feel free to use a class and optimize it! + """ + response = requests.get(url, headers=headers) + soup = BeautifulSoup(response.content, "html.parser") + links = [link.get("href") for link in soup.find_all("a")] + return [link for link in links if link] diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/Multi-modalAssistant_day5.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/Multi-modalAssistant_day5.ipynb new file mode 100644 index 0000000..c7958eb --- /dev/null +++ b/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/Multi-modalAssistant_day5.ipynb @@ -0,0 +1,339 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1665a5cf", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import time\n", + "import json\n", + "import sqlite3\n", + "import tempfile\n", + "from pathlib import Path\n", + "from dotenv import load_dotenv\n", + "import gradio as gr\n", + "from openai import OpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5cb6632c", + "metadata": {}, + "outputs": [], + "source": [ + "load_dotenv()\n", + "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n", + "DB_PATH = \"nova_support.db\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cd3ac8c", + "metadata": {}, + "outputs": [], + "source": [ + "def init_db():\n", + " conn = sqlite3.connect(DB_PATH)\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " CREATE TABLE IF NOT EXISTS tickets (\n", + " ticket_id TEXT PRIMARY KEY,\n", + " name TEXT,\n", + " company TEXT,\n", + " email TEXT,\n", + " issue TEXT,\n", + " priority TEXT,\n", + " status TEXT,\n", + " created_at TEXT\n", + " )\n", + " \"\"\")\n", + " conn.commit()\n", + " conn.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70e0556c", + "metadata": {}, + "outputs": [], + "source": [ + "def new_ticket_id():\n", + " conn = sqlite3.connect(DB_PATH)\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT COUNT(*) FROM tickets\")\n", + " count = cur.fetchone()[0]\n", + " 
conn.close()\n", + " return f\"RT-{1001 + count}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38525d5c", + "metadata": {}, + "outputs": [], + "source": [ + "def create_ticket(name, company, email, issue, priority=\"P3\"):\n", + " tid = new_ticket_id()\n", + " ts = time.strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " conn = sqlite3.connect(DB_PATH)\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " INSERT INTO tickets (ticket_id, name, company, email, issue, priority, status, created_at)\n", + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)\n", + " \"\"\", (tid, name, company, email, issue, priority.upper(), \"OPEN\", ts))\n", + " conn.commit()\n", + " conn.close()\n", + " return tid, ts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58e803c5", + "metadata": {}, + "outputs": [], + "source": [ + "def get_ticket(ticket_id):\n", + " conn = sqlite3.connect(DB_PATH)\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT * FROM tickets WHERE ticket_id=?\", (ticket_id,))\n", + " row = cur.fetchone()\n", + " conn.close()\n", + " if not row:\n", + " return None\n", + " keys = [\"ticket_id\", \"name\", \"company\", \"email\", \"issue\", \"priority\", \"status\", \"created_at\"]\n", + " return dict(zip(keys, row))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b97601ff", + "metadata": {}, + "outputs": [], + "source": [ + "def synthesize_speech(text):\n", + " if not text.strip():\n", + " return None\n", + " output_path = Path(tempfile.gettempdir()) / \"nova_reply.mp3\"\n", + " with client.audio.speech.with_streaming_response.create(\n", + " model=\"gpt-4o-mini-tts\",\n", + " voice=\"alloy\",\n", + " input=text\n", + " ) as response:\n", + " response.stream_to_file(output_path)\n", + " return str(output_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4e20aad", + "metadata": {}, + "outputs": [], + "source": [ + "SYSTEM_PROMPT = \"\"\"\n", + "You are Nova, the AI Support and Sales Assistant for Reallytics.ai.\n", + "You help customers with:\n", + "- Reporting issues (create tickets)\n", + "- Checking existing tickets\n", + "- Providing product/service information\n", + "- Explaining pricing ranges\n", + "- Reassuring integration compatibility with client systems\n", + "Respond in a professional, business tone.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d1c094d", + "metadata": {}, + "outputs": [], + "source": [ + "def detect_intent(message):\n", + " text = message.lower()\n", + " if any(k in text for k in [\"create ticket\", \"open ticket\", \"new ticket\", \"issue\", \"problem\"]):\n", + " return \"create_ticket\"\n", + " if re.search(r\"rt-\\d+\", text):\n", + " return \"check_ticket\"\n", + " if \"price\" in text or \"cost\" in text:\n", + " return \"pricing\"\n", + " if \"integration\" in text:\n", + " return \"integration\"\n", + " return \"general\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed9114d5", + "metadata": {}, + "outputs": [], + "source": [ + "def chat(message, history, model, name, company, email):\n", + " history_msgs = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n", + " intent = detect_intent(message)\n", + "\n", + " if intent == \"create_ticket\":\n", + " priority = \"P2\" if \"urgent\" in message.lower() or \"high\" in message.lower() else \"P3\"\n", + " tid, ts = create_ticket(name, company, email, message, priority)\n", + " text = f\"A new support ticket has been 
created.\\nTicket ID: {tid}\\nCreated at: {ts}\\nStatus: OPEN\"\n", + " yield text, synthesize_speech(text)\n", + " return\n", + "\n", + " if intent == \"check_ticket\":\n", + " match = re.search(r\"(rt-\\d+)\", message.lower())\n", + " if match:\n", + " ticket_id = match.group(1).upper()\n", + " data = get_ticket(ticket_id)\n", + " if data:\n", + " text = (\n", + " f\"Ticket {ticket_id} Details:\\n\"\n", + " f\"Issue: {data['issue']}\\n\"\n", + " f\"Status: {data['status']}\\n\"\n", + " f\"Priority: {data['priority']}\\n\"\n", + " f\"Created at: {data['created_at']}\"\n", + " )\n", + " else:\n", + " text = f\"No ticket found with ID {ticket_id}.\"\n", + " else:\n", + " text = \"Please provide a valid ticket ID.\"\n", + " yield text, synthesize_speech(text)\n", + " return" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "280c7d2f", + "metadata": {}, + "outputs": [], + "source": [ + "def chat(message, history, model, name, company, email):\n", + " if not message.strip():\n", + " yield \"Please type a message to start.\", None\n", + " return\n", + "\n", + " history_msgs = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n", + " intent = detect_intent(message)\n", + " reply, audio_path = \"\", None\n", + "\n", + " if intent == \"create_ticket\":\n", + " priority = \"P2\" if \"urgent\" in message.lower() or \"high\" in message.lower() else \"P3\"\n", + " tid, ts = create_ticket(name, company, email, message, priority)\n", + " reply = f\"A new support ticket has been created.\\nTicket ID: {tid}\\nCreated at: {ts}\\nStatus: OPEN\"\n", + " audio_path = synthesize_speech(reply)\n", + " yield reply, audio_path\n", + " return\n", + "\n", + " if intent == \"check_ticket\":\n", + " match = re.search(r\"(rt-\\d+)\", message.lower())\n", + " if match:\n", + " ticket_id = match.group(1).upper()\n", + " data = get_ticket(ticket_id)\n", + " if data:\n", + " reply = (\n", + " f\"Ticket {ticket_id} Details:\\n\"\n", + " f\"Issue: {data['issue']}\\n\"\n", + " f\"Status: {data['status']}\\n\"\n", + " f\"Priority: {data['priority']}\\n\"\n", + " f\"Created at: {data['created_at']}\"\n", + " )\n", + " else:\n", + " reply = f\"No ticket found with ID {ticket_id}.\"\n", + " else:\n", + " reply = \"Please provide a valid ticket ID.\"\n", + " audio_path = synthesize_speech(reply)\n", + " yield reply, audio_path\n", + " return\n", + "\n", + " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history_msgs + [{\"role\": \"user\", \"content\": message}]\n", + " stream = client.chat.completions.create(model=model, messages=messages, stream=True)\n", + "\n", + " full_reply = \"\"\n", + " for chunk in stream:\n", + " delta = chunk.choices[0].delta.content or \"\"\n", + " full_reply += delta\n", + " yield full_reply, None \n", + " audio_path = synthesize_speech(full_reply)\n", + " yield full_reply, audio_path " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cb1977d", + "metadata": {}, + "outputs": [], + "source": [ + "init_db()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a0557ba", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(title=\"Nova | Business AI Assistant\", theme=gr.themes.Soft()) as demo:\n", + " gr.Markdown(\"## Nova | Reallytics.ai Customer Support & Sales Assistant\")\n", + " gr.Markdown(\n", + " \"Nova helps clients create or track support tickets, understand services, and explore automation options. 
\"\n", + " \"Type your questions and Nova will respond in both text and voice.\"\n", + " )\n", + "\n", + " with gr.Row():\n", + " name = gr.Textbox(label=\"Your Name\", placeholder=\"Liam\")\n", + " company = gr.Textbox(label=\"Company (optional)\", placeholder=\"ABC Corp\")\n", + " email = gr.Textbox(label=\"Email\", placeholder=\"you@example.com\")\n", + "\n", + " model = gr.Dropdown([\"gpt-4o-mini\", \"gpt-4\", \"gpt-3.5-turbo\"], value=\"gpt-4o-mini\", label=\"Model\")\n", + "\n", + " audio_output = gr.Audio(label=\"Nova's Voice Reply\", autoplay=True, interactive=False)\n", + "\n", + " gr.ChatInterface(\n", + " fn=chat,\n", + " type=\"messages\",\n", + " additional_inputs=[model, name, company, email],\n", + " additional_outputs=[audio_output],\n", + " title=\"Chat with Nova\",\n", + " description=\"Ask about tickets, automation services, pricing, or integration and Nova will also speak her reply.\"\n", + " )\n", + "\n", + "if __name__ == \"__main__\":\n", + " demo.launch()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/nova_support.db b/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/nova_support.db new file mode 100644 index 0000000..0e0b415 Binary files /dev/null and b/community-contributions/muhammad_qasim_sheikh/Week 2/day 5/nova_support.db differ diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day1/3way_conversation_day1.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 2/day1/3way_conversation_day1.ipynb deleted file mode 100644 index 0eb6d74..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 2/day1/3way_conversation_day1.ipynb +++ /dev/null @@ -1,144 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "d59206dc", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from dotenv import load_dotenv\n", - "from openai import OpenAI\n", - "import ollama\n", - "from IPython.display import Markdown, display" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad035727", - "metadata": {}, - "outputs": [], - "source": [ - "# Load keys\n", - "load_dotenv()\n", - "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n", - "ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key = 'ollama')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f521334", - "metadata": {}, - "outputs": [], - "source": [ - "# ---- SYSTEM PROMPTS ----\n", - "athena_system = \"\"\"\n", - "You are Athena, a strategic thinker and visionary. You seek meaning, long-term implications,\n", - "and practical wisdom in every discussion. Be concise (1-2 sentences).\n", - "\"\"\"\n", - "\n", - "loki_system = \"\"\"\n", - "You are Loki, a sarcastic trickster who mocks and challenges everyone else's opinions.\n", - "You use humor, wit, and irony to undermine serious arguments. Be concise (1-2 sentences).\n", - "\"\"\"\n", - "\n", - "orion_system = \"\"\"\n", - "You are Orion, a data-driven realist. 
You respond with evidence, statistics, or factual analysis.\n", - "If data is not available, make a logical deduction. Be concise (1-2 sentences).\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0a6d04f6", - "metadata": {}, - "outputs": [], - "source": [ - "# ---- INITIAL CONVERSATION ----\n", - "conversation = [\n", - " {\"role\": \"system\", \"name\": \"Athena\", \"content\": athena_system},\n", - " {\"role\": \"system\", \"name\": \"Loki\", \"content\": loki_system},\n", - " {\"role\": \"system\", \"name\": \"Orion\", \"content\": orion_system},\n", - " {\"role\": \"user\", \"content\": \"Topic: 'Why did the chicken cross the road?' Begin your discussion.\"}\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e292a27b", - "metadata": {}, - "outputs": [], - "source": [ - "# ---- HELPER FUNCTIONS ----\n", - "def call_gpt(name, system_prompt, conversation):\n", - " \"\"\"Call GPT model with current conversation context.\"\"\"\n", - " messages = [{\"role\": \"system\", \"content\": system_prompt}]\n", - " messages += [{\"role\": \"user\", \"content\": f\"The conversation so far:\\n{format_conversation(conversation)}\\nNow respond as {name}.\"}]\n", - " resp = client.chat.completions.create(model=\"gpt-4o-mini\", messages=messages)\n", - " return resp.choices[0].message.content.strip()\n", - "\n", - "def call_ollama(name, system_prompt, conversation):\n", - " \"\"\"Call Ollama (Llama3.2) as a local model.\"\"\"\n", - " messages = [{\"role\": \"system\", \"content\": system_prompt}]\n", - " messages += [{\"role\": \"user\", \"content\": f\"The conversation so far:\\n{format_conversation(conversation)}\\nNow respond as {name}.\"}]\n", - " resp = ollama.chat(model=\"llama3.2\", messages=messages)\n", - " return resp['message']['content'].strip()\n", - "\n", - "def format_conversation(conv):\n", - " return \"\\n\".join([f\"{m.get('name', m['role']).upper()}: {m['content']}\" for m in conv if m['role'] != \"system\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0eb4d72", - "metadata": {}, - "outputs": [], - "source": [ - "# ---- MAIN LOOP ----\n", - "rounds = 5\n", - "for i in range(rounds):\n", - " # Athena responds\n", - " athena_reply = call_gpt(\"Athena\", athena_system, conversation)\n", - " conversation.append({\"role\": \"assistant\", \"name\": \"Athena\", \"content\": athena_reply})\n", - " display(Markdown(f\"**Athena:** {athena_reply}\"))\n", - "\n", - " # Loki responds\n", - " loki_reply = call_ollama(\"Loki\", loki_system, conversation)\n", - " conversation.append({\"role\": \"assistant\", \"name\": \"Loki\", \"content\": loki_reply})\n", - " display(Markdown(f\"**Loki:** {loki_reply}\"))\n", - "\n", - " # Orion responds\n", - " orion_reply = call_gpt(\"Orion\", orion_system, conversation)\n", - " conversation.append({\"role\": \"assistant\", \"name\": \"Orion\", \"content\": orion_reply})\n", - " display(Markdown(f\"**Orion:** {orion_reply}\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llm-engineering", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day1/readme.md 
b/community-contributions/muhammad_qasim_sheikh/Week 2/day1/readme.md deleted file mode 100644 index 26c26a7..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 2/day1/readme.md +++ /dev/null @@ -1,47 +0,0 @@ -# Multi-Agent Conversation Simulator (OpenAI + Ollama) - -## Project Overview - -This project is an experimental **multi-agent conversational simulation** built with **OpenAI GPT models** and a locally-hosted **Ollama LLM (Llama 3.2)**. It demonstrates how multiple AI personas can participate in a shared conversation, each with distinct roles, perspectives, and behaviors — producing a dynamic, evolving debate from different angles. - -The script orchestrates a **three-way dialogue** around a single topic (“Why did the chicken cross the road?”) between three agents, each powered by a different model and persona definition: - -- **Athena (OpenAI GPT-4o):** A strategic thinker who looks for deeper meaning, long-term consequences, and practical wisdom. -- **Loki (Ollama Llama 3.2):** A sarcastic trickster who mocks, questions, and challenges the others with wit and irony. -- **Orion (OpenAI GPT-4o):** A data-driven realist who grounds the discussion in facts, statistics, or logical deductions. - -## What’s Happening in the Code - -1. **Environment Setup** - - Loads the OpenAI API key from a `.env` file. - - Initializes OpenAI’s Python client and configures a local Ollama endpoint. - -2. **Persona System Prompts** - - Defines system prompts for each agent to give them unique personalities and communication styles. - - These prompts act as the “character definitions” for Athena, Loki, and Orion. - -3. **Conversation Initialization** - - Starts with a single conversation topic provided by the user. - - All three agents are aware of the discussion context and prior messages. - -4. **Conversation Loop** - - The conversation runs in multiple rounds (default: 5). - - In each round: - - **Athena (GPT)** responds first with a strategic viewpoint. - - **Loki (Ollama)** replies next, injecting sarcasm and skepticism. - - **Orion (GPT)** follows with a fact-based or analytical perspective. - - Each response is appended to the conversation history so future replies build on previous statements. - -5. **Dynamic Context Sharing** - - Each agent receives the **entire conversation so far** as context before generating a response. - - This ensures their replies are relevant, coherent, and responsive to what the others have said. - -6. **Output Rendering** - - Responses are displayed as Markdown in a readable, chat-like format for each speaker, round by round. - -## Key Highlights - -- Demonstrates **multi-agent orchestration** with different models working together in a single script. -- Uses **OpenAI GPT models** for reasoning and **Ollama (Llama 3.2)** for local, cost-free inference. -- Shows how **system prompts** and **context-aware message passing** can simulate realistic dialogues. -- Provides a template for experimenting with **AI characters**, **debate simulations**, or **collaborative agent systems**. 
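As a compact recap of the orchestration pattern, the whole round loop reduces to the sketch below. It is illustrative only and reuses the `call_gpt`/`call_ollama` helpers and the three persona prompts defined in the notebook above; nothing here is new API surface.

```python
# Condensed sketch of the round loop, assuming call_gpt / call_ollama and the
# athena_system / loki_system / orion_system prompts defined in the notebook.
conversation = [
    {"role": "user", "content": "Topic: 'Why did the chicken cross the road?' Begin your discussion."}
]

agents = [
    ("Athena", athena_system, call_gpt),     # strategic viewpoint first
    ("Loki", loki_system, call_ollama),      # sarcasm and skepticism next
    ("Orion", orion_system, call_gpt),       # facts and analysis last
]

for _ in range(5):  # default: 5 rounds
    for name, system_prompt, call in agents:
        # each agent receives the entire conversation so far as context
        reply = call(name, system_prompt, conversation)
        conversation.append({"role": "assistant", "name": name, "content": reply})
```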
diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day2/gradio_simple_UI_day2.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 2/day2/gradio_simple_UI_day2.ipynb deleted file mode 100644 index caeb07d..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 2/day2/gradio_simple_UI_day2.ipynb +++ /dev/null @@ -1,224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "4ef1e715", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import gradio as gr\n", - "from openai import OpenAI\n", - "from dotenv import load_dotenv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3426558", - "metadata": {}, - "outputs": [], - "source": [ - "# Load API key\n", - "load_dotenv()\n", - "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e18a59a3", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# Helper: Prompt Builder\n", - "# -------------------------------\n", - "def build_prompt(task, topic, tone, audience):\n", - " task_prompts = {\n", - " \"Brochure\": f\"Write a compelling marketing brochure about {topic}.\",\n", - " \"Blog Post\": f\"Write a blog post on {topic} with engaging storytelling and useful insights.\",\n", - " \"Product Comparison\": f\"Write a product comparison summary focusing on {topic}, including pros, cons, and recommendations.\",\n", - " \"Idea Brainstorm\": f\"Brainstorm creative ideas or solutions related to {topic}.\"\n", - " }\n", - " base = task_prompts.get(task, \"Write something creative.\")\n", - " if tone:\n", - " base += f\" Use a {tone} tone.\"\n", - " if audience:\n", - " base += f\" Tailor it for {audience}.\"\n", - " return base" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65a27bfb", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# Generate with multiple models\n", - "# -------------------------------\n", - "def generate_stream(task, topic, tone, audience, model):\n", - " if not topic.strip():\n", - " yield \"⚠️ Please enter a topic.\"\n", - " return\n", - "\n", - " prompt = build_prompt(task, topic, tone, audience)\n", - "\n", - " stream = client.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", - " {\"role\": \"user\", \"content\": prompt}\n", - " ],\n", - " max_tokens=800,\n", - " stream=True\n", - " )\n", - "\n", - " result = \"\"\n", - " for chunk in stream:\n", - " result += chunk.choices[0].delta.content or \"\"\n", - " yield result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9e15abee", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# Refinement logic\n", - "# -------------------------------\n", - "def refine_stream(original_text, instruction, model):\n", - " if not original_text.strip():\n", - " yield \"⚠️ Please paste the text you want to refine.\"\n", - " return\n", - " if not instruction.strip():\n", - " yield \"⚠️ Please provide a refinement instruction.\"\n", - " return\n", - "\n", - " refined_prompt = f\"Refine the following text based on this instruction: {instruction}\\n\\nText:\\n{original_text}\"\n", - "\n", - " stream = client.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": \"You are a writing assistant.\"},\n", - " 
{\"role\": \"user\", \"content\": refined_prompt}\n", - " ],\n", - " max_tokens=800,\n", - " stream=True\n", - " )\n", - "\n", - " result = \"\"\n", - " for chunk in stream:\n", - " result += chunk.choices[0].delta.content or \"\"\n", - " yield result\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ee02feb", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# Gradio UI\n", - "# -------------------------------\n", - "with gr.Blocks(title=\"AI Creative Studio\") as demo:\n", - " gr.Markdown(\"# AI Creative Studio\\nGenerate marketing content, blog posts, or creative ideas — streamed in real-time!\")\n", - "\n", - " with gr.Row():\n", - " task = gr.Dropdown(\n", - " [\"Brochure\", \"Blog Post\", \"Product Comparison\", \"Idea Brainstorm\"],\n", - " label=\"Task Type\",\n", - " value=\"Brochure\"\n", - " )\n", - " topic = gr.Textbox(label=\"Topic\", placeholder=\"e.g., Electric Cars, AI in Education...\")\n", - " with gr.Row():\n", - " tone = gr.Textbox(label=\"Tone (optional)\", placeholder=\"e.g., professional, casual, humorous...\")\n", - " audience = gr.Textbox(label=\"Target Audience (optional)\", placeholder=\"e.g., investors, students, developers...\")\n", - "\n", - " model = gr.Dropdown(\n", - " [\"gpt-4o-mini\", \"gpt-3.5-turbo\", \"gpt-4\"],\n", - " label=\"Choose a model\",\n", - " value=\"gpt-4o-mini\"\n", - " )\n", - "\n", - " generate_btn = gr.Button(\"Generate Content\")\n", - " output_md = gr.Markdown(label=\"Generated Content\", show_label=True)\n", - "\n", - " generate_btn.click(\n", - " fn=generate_stream,\n", - " inputs=[task, topic, tone, audience, model],\n", - " outputs=output_md\n", - " )\n", - "\n", - " gr.Markdown(\"---\\n## Refine Your Content\")\n", - "\n", - " original_text = gr.Textbox(\n", - " label=\"Original Content\",\n", - " placeholder=\"Paste content you want to refine...\",\n", - " lines=10\n", - " )\n", - " instruction = gr.Textbox(\n", - " label=\"Refinement Instruction\",\n", - " placeholder=\"e.g., Make it shorter and more persuasive.\",\n", - " )\n", - " refine_model = gr.Dropdown(\n", - " [\"gpt-4o-mini\", \"gpt-3.5-turbo\", \"gpt-4\"],\n", - " label=\"Model for Refinement\",\n", - " value=\"gpt-4o-mini\"\n", - " )\n", - "\n", - " refine_btn = gr.Button(\"Refine\")\n", - " refined_output = gr.Markdown(label=\"Refined Content\", show_label=True)\n", - "\n", - " refine_btn.click(\n", - " fn=refine_stream,\n", - " inputs=[original_text, instruction, refine_model],\n", - " outputs=refined_output\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55d42c7e", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# Launch the App\n", - "# -------------------------------\n", - "if __name__ == \"__main__\":\n", - " demo.launch()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llm-engineering", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day2/readme.md b/community-contributions/muhammad_qasim_sheikh/Week 2/day2/readme.md deleted file mode 100644 index dc0bf08..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 
2/day2/readme.md +++ /dev/null @@ -1,48 +0,0 @@ -# AI Creative Studio - -## Project Overview - -AI Creative Studio is a web-based application built with Gradio that allows users to generate and refine high-quality written content in real time using OpenAI language models. It is designed as a flexible creative tool for content creation tasks such as writing brochures, blog posts, product comparisons, and brainstorming ideas. The application also supports interactive refinement, enabling users to improve or adapt existing text based on specific instructions. - -The core idea is to combine the power of OpenAI models with an intuitive, user-friendly interface that streams responses as they are generated. This provides a fast, engaging, and highly interactive writing experience without waiting for the entire response to complete before it appears. - ---- - -## What’s Happening in the Project - -1. **Environment Setup and Model Initialization** - - The application loads the OpenAI API key from a `.env` file and initializes the OpenAI client for model interactions. - - Supported models include `gpt-4o-mini`, `gpt-3.5-turbo`, and `gpt-4`, which the user can select from a dropdown menu. - -2. **Prompt Construction and Content Generation** - - The `build_prompt` function constructs a task-specific prompt based on the user’s choices: content type (brochure, blog post, etc.), topic, tone, and target audience. - - Once the user provides the inputs and selects a model, the application sends the prompt to the model. - - The model’s response is streamed back incrementally, showing text chunk by chunk for a real-time generation experience. - -3. **Content Refinement Feature** - - Users can paste existing text and provide a refinement instruction (e.g., “make it more persuasive” or “summarize it”). - - The application then streams an improved version of the text, following the instruction, allowing users to iterate and polish content efficiently. - -4. **Gradio User Interface** - - The app is built using Gradio Blocks, providing an organized and interactive layout. - - Key UI elements include: - - Task selection dropdown for choosing the type of content. - - Text inputs for topic, tone, and target audience. - - Model selection dropdown for choosing a specific OpenAI model. - - Real-time markdown display of generated content. - - A refinement panel for improving existing text. - -5. **Streaming Workflow** - - Both generation and refinement use OpenAI’s streaming API to display the model’s response as it’s produced. - - This provides an immediate and responsive user experience, allowing users to see results build up in real time rather than waiting for the entire completion. - ---- - -### Key Features -- Real-time streaming responses for fast and interactive content creation. -- Multiple content generation modes: brochure, blog post, product comparison, and idea brainstorming. -- Customization options for tone and audience to tailor the writing style. -- Interactive refinement tool to enhance or transform existing text. -- Clean and intuitive web interface powered by Gradio. - -AI Creative Studio demonstrates how large language models can be integrated into user-facing applications to support creative workflows and improve productivity in content generation and editing. 
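The streaming workflow in particular comes down to one small pattern: accumulate each delta and yield the growing string so Gradio re-renders the output at every step. A minimal sketch, assuming an `OPENAI_API_KEY` in the environment:

```python
# Minimal sketch of the streaming pattern used for both generation and refinement.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def stream_completion(prompt: str, model: str = "gpt-4o-mini"):
    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    result = ""
    for chunk in stream:
        # delta.content can be None on the final chunk, hence the `or ""`
        result += chunk.choices[0].delta.content or ""
        yield result  # Gradio re-renders the output on every yield
```

Yielding the accumulated string, rather than each individual delta, is what lets a Markdown output component re-render cleanly on every update.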
diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day3/ChatUI_day3.ipynb b/community-contributions/muhammad_qasim_sheikh/Week 2/day3/ChatUI_day3.ipynb deleted file mode 100644 index c76b68e..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 2/day3/ChatUI_day3.ipynb +++ /dev/null @@ -1,137 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "6f612c5a", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import gradio as gr\n", - "from dotenv import load_dotenv\n", - "from openai import OpenAI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39c144fd", - "metadata": {}, - "outputs": [], - "source": [ - "# Load API Key\n", - "load_dotenv()\n", - "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f656e0d1", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# 1. System Prompt (Business Context)\n", - "# -------------------------------\n", - "system_message = \"\"\"\n", - "You are Nova, an AI Sales & Solutions Consultant for Reallytics.ai a company specializing in building\n", - "custom AI chatbots, voice assistants, data dashboards, and automation solutions for businesses.\n", - "You are professional, insightful, and always focused on solving the user's business challenges.\n", - "First, try to understand their use case. Then suggest relevant solutions from our services with clear value propositions.\n", - "If the user is unsure, give them examples of how similar businesses have benefited from AI.\n", - "\"\"\"\n", - "\n", - "MODEL = \"gpt-4o-mini\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f2faba29", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# 2. Smart Chat Function (Streaming)\n", - "# -------------------------------\n", - "def chat(message, history):\n", - " # Convert Gradio's chat history to OpenAI format\n", - " history_messages = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n", - "\n", - " # Adjust system message based on context dynamically\n", - " relevant_system_message = system_message\n", - " if \"price\" in message.lower():\n", - " relevant_system_message += (\n", - " \" If the user asks about pricing, explain that pricing depends on project complexity, \"\n", - " \"but typical POCs start around $2,000 - $5,000, and full enterprise deployments scale beyond that.\"\n", - " )\n", - " if \"integration\" in message.lower():\n", - " relevant_system_message += (\n", - " \" If integration is mentioned, reassure the user that our solutions are built to integrate seamlessly with CRMs, ERPs, or internal APIs.\"\n", - " )\n", - "\n", - " # Compose final messages\n", - " messages = [{\"role\": \"system\", \"content\": relevant_system_message}] + history_messages + [\n", - " {\"role\": \"user\", \"content\": message}\n", - " ]\n", - "\n", - " # Stream the response\n", - " stream = client.chat.completions.create(\n", - " model=MODEL,\n", - " messages=messages,\n", - " stream=True\n", - " )\n", - "\n", - " response = \"\"\n", - " for chunk in stream:\n", - " response += chunk.choices[0].delta.content or \"\"\n", - " yield response" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9d9515e", - "metadata": {}, - "outputs": [], - "source": [ - "# -------------------------------\n", - "# 3. 
Gradio Chat UI\n", - "# -------------------------------\n", - "with gr.Blocks(title=\"AI Business Assistant\") as demo:\n", - " gr.Markdown(\"# AI Business Assistant\\nYour intelligent sales and solution consultant, powered by OpenAI.\")\n", - "\n", - " \n", - "gr.ChatInterface(\n", - " fn=chat,\n", - " type=\"messages\",\n", - " title=\"Business AI Consultant\",\n", - " description=\"Ask about automation, chatbots, dashboards, or voice AI Nova will help you discover the right solution.\"\n", - ").launch()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llm-engineering", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/community-contributions/muhammad_qasim_sheikh/Week 2/day3/readme.md b/community-contributions/muhammad_qasim_sheikh/Week 2/day3/readme.md deleted file mode 100644 index abb8189..0000000 --- a/community-contributions/muhammad_qasim_sheikh/Week 2/day3/readme.md +++ /dev/null @@ -1,42 +0,0 @@ -# AI Business Assistant - -## Project Overview - -This project is a prototype of an **AI-powered business consultant chatbot** built with **Gradio** and **OpenAI**. The assistant, named **Nova**, is designed to act as a virtual sales and solutions consultant for a company offering AI services such as chatbots, voice assistants, dashboards, and automation tools. - -The purpose of the project is to demonstrate how an LLM (Large Language Model) can be adapted for a business context by carefully designing the **system prompt** and providing **dynamic behavior** based on user inputs. The chatbot responds to user queries in real time with streaming responses, making it interactive and natural to use. - - -## What’s Happening in the Code - -1. **Environment Setup** - - The code loads the OpenAI API key from a `.env` file. - - The `OpenAI` client is initialized for communication with the language model. - - The chosen model is `gpt-4o-mini`. - -2. **System Prompt for Business Context** - - The assistant is given a clear identity: *Nova, an AI Sales & Solutions Consultant for Reallytics.ai*. - - The system prompt defines Nova’s tone (professional, insightful) and role (understand user needs, propose relevant AI solutions, share examples). - -3. **Dynamic Chat Function** - - The `chat()` function processes user input and the conversation history. - - It modifies the system prompt dynamically: - - If the user mentions **price**, Nova explains pricing ranges and factors. - - If the user mentions **integration**, Nova reassures the user about system compatibility. - - Messages are formatted for the OpenAI API, combining system, history, and user inputs. - - Responses are streamed back chunk by chunk, so users see the assistant typing in real time. - -4. **Gradio Chat Interface** - - A Gradio interface is created with `ChatInterface` in `messages` mode. - - This automatically provides a chat-style UI with user/assistant message bubbles and a send button. - - The title and description help set context for end users: *“Ask about automation, chatbots, dashboards, or voice AI.”* - - -## Key Features -- **Business-specific persona:** The assistant is contextualized as a sales consultant rather than a generic chatbot. 
-- **Adaptive responses:** System prompt is adjusted based on keywords like "price" and "integration". -- **Streaming output:** Responses are displayed incrementally, improving user experience. -- **Clean chat UI:** Built with Gradio’s `ChatInterface` for simplicity and usability. - - -This project demonstrates how to combine **system prompts**, **dynamic context handling**, and **Gradio chat interfaces** to build a specialized AI assistant tailored for business use cases. diff --git a/week1/community-contributions/fernando/day2.ipynb b/week1/community-contributions/fernando/day2.ipynb new file mode 100644 index 0000000..4a6e7b5 --- /dev/null +++ b/week1/community-contributions/fernando/day2.ipynb @@ -0,0 +1,494 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d15d8294-3328-4e07-ad16-8a03e9bbfdb9", + "metadata": {}, + "source": [ + "# Welcome to the Day 2 Lab!\n" + ] + }, + { + "cell_type": "markdown", + "id": "ada885d9-4d42-4d9b-97f0-74fbbbfe93a9", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + "

Just before we get started --

\n", + " I thought I'd take a second to point you at this page of useful resources for the course. This includes links to all the slides.
\n", + " https://edwarddonner.com/2024/11/13/llm-engineering-resources/
\n", + " Please keep this bookmarked, and I'll continue to add more useful links there over time.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "79ffe36f", + "metadata": {}, + "source": [ + "## First - let's talk about the Chat Completions API\n", + "\n", + "1. The simplest way to call an LLM\n", + "2. It's called Chat Completions because it's saying: \"here is a conversation, please predict what should come next\"\n", + "3. The Chat Completions API was invented by OpenAI, but it's so popular that everybody uses it!\n", + "\n", + "### We will start by calling OpenAI again - but don't worry non-OpenAI people, your time is coming!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e38f17a0", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv(override=True)\n", + "api_key = os.getenv('OPENAI_API_KEY')\n", + "\n", + "if not api_key:\n", + " print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n", + "elif not api_key.startswith(\"sk-proj-\"):\n", + " print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n", + "else:\n", + " print(\"API key found and looks good so far!\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "97846274", + "metadata": {}, + "source": [ + "## Do you know what an Endpoint is?\n", + "\n", + "If not, please review the Technical Foundations guide in the guides folder\n", + "\n", + "And, here is an endpoint that might interest you..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5af5c188", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "headers = {\"Authorization\": f\"Bearer {api_key}\", \"Content-Type\": \"application/json\"}\n", + "\n", + "payload = {\n", + " \"model\": \"gpt-5-nano\",\n", + " \"messages\": [\n", + " {\"role\": \"user\", \"content\": \"Tell me a fun fact\"}]\n", + "}\n", + "\n", + "payload" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d0ab242", + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " \"https://api.openai.com/v1/chat/completions\",\n", + " headers=headers,\n", + " json=payload\n", + ")\n", + "\n", + "response.json()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb11a9f6", + "metadata": {}, + "outputs": [], + "source": [ + "response.json()[\"choices\"][0][\"message\"][\"content\"]" + ] + }, + { + "cell_type": "markdown", + "id": "cea3026a", + "metadata": {}, + "source": [ + "# What is the openai package?\n", + "\n", + "It's known as a Python Client Library.\n", + "\n", + "It's nothing more than a wrapper around making this exact call to the http endpoint.\n", + "\n", + "It just allows you to work with nice Python code instead of messing around with janky json objects.\n", + "\n", + "But that's it. It's open-source and lightweight. 
Some people think it contains OpenAI model code - it doesn't!\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "490fdf09", + "metadata": {}, + "outputs": [], + "source": [ + "# Create OpenAI client\n", + "\n", + "from openai import OpenAI\n", + "openai = OpenAI()\n", + "\n", + "response = openai.chat.completions.create(model=\"gpt-5-nano\", messages=[{\"role\": \"user\", \"content\": \"Tell me a fun fact\"}])\n", + "\n", + "response.choices[0].message.content\n", + "\n" + ] + },
+ { + "cell_type": "markdown", + "id": "c7739cda", + "metadata": {}, + "source": [ + "## And then this great thing happened:\n", + "\n", + "OpenAI's Chat Completions API was so popular that the other model providers created endpoints that are identical.\n", + "\n", + "They are known as the \"OpenAI Compatible Endpoints\".\n", + "\n", + "For example, Google made one here: https://generativelanguage.googleapis.com/v1beta/openai/\n", + "\n", + "And OpenAI decided to be kind: they said, hey, you can just use the same client library that we made for GPT. We'll allow you to specify a different endpoint URL and a different key, to use another provider.\n", + "\n", + "So you can use:\n", + "\n", + "```python\n", + "gemini = OpenAI(base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\", api_key=\"AIz....\")\n", + "gemini.chat.completions.create(...)\n", + "```\n", + "\n", + "And to be clear - even though OpenAI is in the code, we're only using this lightweight python client library to call the endpoint - there's no OpenAI model involved here.\n", + "\n", + "If you're confused, please review Guide 9 in the Guides folder!\n", + "\n", + "And now let's try it!" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "f74293bc", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "GEMINI_BASE_URL = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "\n", + "google_api_key = os.getenv(\"GOOGLE_API_KEY\")\n", + "\n", + "if not google_api_key:\n", + " print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n", + "elif not google_api_key.startswith(\"AIz\"):\n", + " print(\"An API key was found, but it doesn't start AIz\")\n", + "else:\n", + " print(\"API key found and looks good so far!\")\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "8fc5520d", + "metadata": {}, + "outputs": [], + "source": [ + "import google.generativeai as genai\n", + "from dotenv import load_dotenv\n", + "import os\n", + "\n", + "load_dotenv()\n", + "genai.configure(api_key=os.getenv(\"GOOGLE_API_KEY\"))\n", + "\n", + "# List the available models (via Google's native SDK, for comparison)\n", + "for model in genai.list_models():\n", + " print(model.name, \"-\", model.supported_generation_methods)\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "d060f484", + "metadata": {}, + "outputs": [], + "source": [ + "import google.generativeai as genai\n", + "from dotenv import load_dotenv\n", + "import os\n", + "\n", + "load_dotenv()\n", + "genai.configure(api_key=os.getenv(\"GOOGLE_API_KEY\"))\n", + "\n", + "model = genai.GenerativeModel(\"models/gemini-2.5-pro\") # Use a model you saw in the list, e.g. \"gemini-1.5-pro\" or \"gemini-1.5-flash\"\n", + "response = model.generate_content(\"Tell me a fun fact\")\n", + "\n", + "print(response.text)" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gemini = OpenAI(base_url=GEMINI_BASE_URL, 
api_key=google_api_key)\n", + "\n", + "response = gemini.chat.completions.create(model=\"models/gemini-2.5-pro\", messages=[{\"role\": \"user\", \"content\": \"Tell me a fun fact\"}])\n", + "\n", + "response.choices[0].message.content" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "a5b069be", + "metadata": {}, + "outputs": [], + "source": [] + },
+ { + "cell_type": "markdown", + "id": "65272432", + "metadata": {}, + "source": [ + "## And Ollama also gives an OpenAI compatible endpoint\n", + "\n", + "...and it's on your local machine!\n", + "\n", + "If the next cell doesn't print \"Ollama is running\" then please open a terminal and run `ollama serve`" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "f06280ad", + "metadata": {}, + "outputs": [], + "source": [ + "requests.get(\"http://localhost:11434\").content" + ] + },
+ { + "cell_type": "markdown", + "id": "c6ef3807", + "metadata": {}, + "source": [ + "### Download llama3.2 from meta\n", + "\n", + "Change this to llama3.2:1b if your computer is smaller.\n", + "\n", + "Don't use llama3.3 or llama4! They are too big for your computer." + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "e633481d", + "metadata": {}, + "outputs": [], + "source": [ + "!ollama pull llama3.2" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "ce240975", + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "response = requests.get(\"http://localhost:11434/v1/models\")\n", + "print(response.json())\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "d9419762", + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "OLLAMA_BASE_URL = \"http://localhost:11434/v1\"\n", + "\n", + "ollama = OpenAI(base_url=OLLAMA_BASE_URL, api_key='ollama')\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "e2456cdf", + "metadata": {}, + "outputs": [], + "source": [ + "# Get a fun fact\n", + "\n", + "response = ollama.chat.completions.create(model=\"llama3.2\", messages=[{\"role\": \"user\", \"content\": \"Tell me a fun fact\"}])\n", + "\n", + "response.choices[0].message.content" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "3d7cebd7", + "metadata": {}, + "outputs": [], + "source": [ + "# Now let's try deepseek-r1:1.5b - this is DeepSeek \"distilled\" into Qwen from Alibaba Cloud\n", + "\n", + "!ollama pull deepseek-r1:1.5b" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "25002f25", + "metadata": {}, + "outputs": [], + "source": [ + "# Alternative route via the OpenAI-compatible client:\n", + "#response = ollama.chat.completions.create(model=\"deepseek-r1:1.5b\", messages=[{\"role\": \"user\", \"content\": \"Tell me a fun fact\"}])\n", + "#response.choices[0].message.content\n", + "\n", + "from ollama import chat # pip install ollama\n", + "\n", + "resp = chat(\n", + " model='deepseek-r1:1.5b',\n", + " messages=[{'role': 'user', 'content': 'Tell me a fun fact'}],\n", + ")\n", + "\n", + "print(resp['message']['content'])\n", + "# or equivalently, attribute access:\n", + "print(resp.message.content)\n" + ] + },
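+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A quick sanity check - a sketch added alongside the lab, not part of the original - the `/v1/models` listing from a few cells back can also go through the OpenAI client we pointed at Ollama, instead of raw `requests`:" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch: list the locally pulled models via the OpenAI-compatible endpoint\n", + "for m in ollama.models.list().data:\n", + " print(m.id)\n" + ] + },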
"1. No API charges - open-source\n", + "2. Data doesn't leave your box\n", + "\n", + "**Disadvantages:**\n", + "1. Significantly less power than Frontier Model\n", + "\n", + "## Recap on installation of Ollama\n", + "\n", + "Simply visit [ollama.com](https://ollama.com) and install!\n", + "\n", + "Once complete, the ollama server should already be running locally. \n", + "If you visit: \n", + "[http://localhost:11434/](http://localhost:11434/)\n", + "\n", + "You should see the message `Ollama is running`. \n", + "\n", + "If not, bring up a new Terminal (Mac) or Powershell (Windows) and enter `ollama serve` \n", + "And in another Terminal (Mac) or Powershell (Windows), enter `ollama pull llama3.2` \n", + "Then try [http://localhost:11434/](http://localhost:11434/) again.\n", + "\n", + "If Ollama is slow on your machine, try using `llama3.2:1b` as an alternative. Run `ollama pull llama3.2:1b` from a Terminal or Powershell, and change the code from `MODEL = \"llama3.2\"` to `MODEL = \"llama3.2:1b\"`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6de38216-6d1c-48c4-877b-86d403f4e0f8", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from scraper import fetch_website_contents\n", + "from IPython.display import Markdown, display\n", + "from ollama import Client \n", + "\n", + "# Cliente Ollama local\n", + "ollama = Client()\n", + "\n", + "system_prompt = \"\"\"\n", + "You are a helpful assistant that analyzes the contents of a website,\n", + "and provides a short, snarky, humorous summary, ignoring text that might be navigation related.\n", + "Respond in markdown. Do not wrap the markdown in a code block - respond just with the markdown.\n", + "\"\"\"\n", + "\n", + "user_prompt_prefix = \"\"\"\n", + "Here are the contents of a website.\n", + "Provide a short summary of this website.\n", + "If it includes news or announcements, then summarize these too.\n", + "\"\"\"\n", + "\n", + "def messages_for(website):\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_prefix + website}\n", + " ]\n", + "\n", + "def summarize(url):\n", + " website = fetch_website_contents(url)\n", + " response = ollama.chat(\n", + " model='llama3.2',\n", + " messages=messages_for(website)\n", + " )\n", + " return response['message']['content']\n", + "\n", + "def display_summary(url):\n", + " summary = summarize(url)\n", + " display(Markdown(summary))\n", + "\n", + "# Ejecuta el resumen\n", + "display_summary(\"https://www.reforma.com\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week1/community-contributions/fernando/week1 EXERCISE.ipynb b/week1/community-contributions/fernando/week1 EXERCISE.ipynb new file mode 100644 index 0000000..c152cb7 --- /dev/null +++ b/week1/community-contributions/fernando/week1 EXERCISE.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5", + "metadata": {}, + "source": [ + "# End of week 1 exercise\n", + "\n", + "To demonstrate your familiarity with OpenAI API, and 
+ ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week1/community-contributions/fernando/week1 EXERCISE.ipynb b/week1/community-contributions/fernando/week1 EXERCISE.ipynb new file mode 100644 index 0000000..c152cb7 --- /dev/null +++ b/week1/community-contributions/fernando/week1 EXERCISE.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [
+ { + "cell_type": "markdown", + "id": "fe12c203-e6a6-452c-a655-afb8a03a4ff5", + "metadata": {}, + "source": [ + "# End of week 1 exercise\n", + "\n", + "To demonstrate your familiarity with the OpenAI API, and also Ollama, build a tool that takes a technical question, \n", + "and responds with an explanation. This is a tool that you will be able to use yourself during the course!" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "c1070317-3ed9-4659-abe3-828943230e03", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "from openai import OpenAI\n", + "from dotenv import load_dotenv" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "4a456906-915a-4bfd-bb9d-57e505c5093f", + "metadata": {}, + "outputs": [], + "source": [ + "# constants\n", + "MODEL_GPT = 'gpt-4o-mini'\n", + "MODEL_LLAMA = 'llama3.2'" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "a8d7923c-5f28-4c30-8556-342d7c8497c1", + "metadata": {}, + "outputs": [], + "source": [ + "# set up environment\n", + "system_prompt = \"\"\"\n", + "You are a technical expert in AI and LLMs.\n", + "\"\"\"\n", + "\n", + "user_prompt_prefix = \"\"\"\n", + "Provide deep explanations of the provided text.\n", + "\"\"\"\n", + "\n", + "user_prompt = \"\"\"\n", + "Explain the provided text.\n", + "\"\"\"\n", + "client = OpenAI()\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "3f0d0137-52b0-47a8-81a8-11a90a010798", + "metadata": {}, + "outputs": [], + "source": [ + "# here is the question; type over this to ask something new\n", + "\n", + "question = \"\"\"\n", + "Ollama does have an OpenAI compatible endpoint, but Gemini doesn't?\n", + "\"\"\"" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get gpt-4o-mini to answer, with streaming\n", + "def messages_for(question):\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_prefix + question}\n", + " ]\n", + "\n", + "def run_model_streaming(model_name, question):\n", + " stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages_for(question),\n", + " stream=True\n", + " )\n", + " for chunk in stream:\n", + " content = chunk.choices[0].delta.content\n", + " if content:\n", + " print(content, end=\"\", flush=True)\n", + "\n", + "run_model_streaming(MODEL_GPT, question)" + ] + },
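+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A small sketch beyond the exercise: the next cell rebinds `client` to the local Ollama endpoint, so a streaming variant that takes the client explicitly keeps the cloud and local providers usable side by side." + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch: same streaming loop, but with an explicit client argument\n", + "def stream_with(client_obj, model_name, question):\n", + " stream = client_obj.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages_for(question),\n", + " stream=True\n", + " )\n", + " for chunk in stream:\n", + " content = chunk.choices[0].delta.content\n", + " if content:\n", + " print(content, end=\"\", flush=True)\n", + "\n", + "# e.g. stream_with(client, MODEL_GPT, question)\n" + ] + },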
" )\n", + " return response.choices[0].message.content\n", + "\n", + "# run and print result\n", + "print(run_model(MODEL_LLAMA, question))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summarizer.ipynb b/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summarizer.ipynb new file mode 100644 index 0000000..d18222b --- /dev/null +++ b/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summarizer.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1fecd49e", + "metadata": {}, + "source": [ + "# 🗺️ Google Maps Review Summarizer\n", + "\n", + "This Python app automates the process of fetching and summarizing Google Maps reviews for any business or location.\n", + "\n", + "## 🚀 Overview\n", + "The app performs two main tasks:\n", + "1. **Scrape Reviews** – Uses a web scraping script to extract reviews directly from Google Maps.\n", + "2. **Summarize Content** – Leverages OpenAI's language models to generate concise, insightful summaries of the collected reviews and analyse the sentiments.\n", + "\n", + "## 🧠 Tech Stack\n", + "- **Python** – Core language\n", + "- **Playwright** – For scraping reviews\n", + "- **OpenAI API** – For natural language summarization\n", + "- **Jupyter Notebook** – For exploration, testing, and demonstration\n", + "\n", + "### 🙏 Credits\n", + "The web scraping logic is **inspired by [Antonello Zanini’s blog post](https://blog.apify.com/how-to-scrape-google-reviews/)** on building a Google Reviews scraper. Special thanks for the valuable insights on **structuring and automating the scraping workflow**, which greatly informed the development of this improved scraper.\n", + "\n", + "This app, however, uses an **enhanced version of the scraper** that can scroll infinitely to load more reviews until it collects **at least 1,000 reviews**. If only a smaller number of reviews are available, the scraper stops scrolling earlier.\n", + "\n", + "## ✅ Sample Output\n", + "Here is a summary of reviews of a restuarant generated by the app.\n", + "\n", + "![Alt text](google-map-review-summary.jpg)\n", + "\n", + "\n", + "---\n", + "\n", + "**Note:** This project is intended for educational and research purposes. 
Please ensure compliance with Google’s [Terms of Service](https://policies.google.com/terms) when scraping or using their data.\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "df04a4aa", + "metadata": {}, + "outputs": [], + "source": [ + "#Activate the llm_engineering virtual environment\n", + "#(note: each '!' line runs in its own shell, so 'source' does not persist - run Jupyter from the activated venv to be safe)\n", + "!source ../../../.venv/bin/activate \n", + "\n", + "#Make sure pip is available and up to date inside the venv\n", + "!python3 -m ensurepip --upgrade\n", + "\n", + "#Verify that pip now points to the venv path (should end with /.venv/bin/pip)\n", + "!which pip3\n", + "\n", + "#Install Playwright inside the venv\n", + "!pip3 install playwright\n", + "\n", + "#Download the required browser binaries and dependencies\n", + "!python3 -m playwright install" + ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "id": "1c794cfd", + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from playwright.async_api import async_playwright\n", + "from IPython.display import Markdown, display\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "id": "317af2b8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "API key found and looks good so far!\n" + ] + } + ], + "source": [ + "# Load environment variables in a file called .env\n", + "\n", + "load_dotenv(override=True)\n", + "api_key = os.getenv('OPENAI_API_KEY')\n", + "\n", + "# Check the key\n", + "\n", + "if not api_key:\n", + " print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n", + "elif not api_key.startswith(\"sk-proj-\"):\n", + " print(\"An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook\")\n", + "elif api_key.strip() != api_key:\n", + " print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n", + "else:\n", + " print(\"API key found and looks good so far!\")\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "id": "6f142c79", + "metadata": {}, + "outputs": [], + "source": [ + "async def scroll_reviews_panel(page, max_scrolls=50, max_reviews=10):\n", + " \"\"\"\n", + " Scrolls through the reviews panel to lazy load all reviews.\n", + " \n", + " Args:\n", + " page: Playwright page object\n", + " max_scrolls: Maximum number of scroll attempts to prevent infinite loops\n", + " max_reviews: Stop once this many reviews have loaded (raise it, e.g. to 1000, to match the README's description)\n", + " \n", + " Returns:\n", + " Number of reviews loaded\n", + " \"\"\"\n", + " # Find the scrollable reviews container\n", + " # Google Maps reviews are in a specific scrollable div\n", + " scrollable_div = page.locator('div[role=\"main\"] div[jslog$=\"mutable:true;\"]').first\n", + " \n", + " previous_review_count = 0\n", + " scroll_attempts = 0\n", + " no_change_count = 0\n", + "\n", + " print(\"Starting to scroll and load reviews...\")\n", + " \n", + " while scroll_attempts < max_scrolls:\n", + " # Get current count of reviews\n", + " review_elements = page.locator(\"div[data-review-id][jsaction]\")\n", + " current_review_count = await review_elements.count()\n", + " \n", + " # If we have loaded max_reviews, we will stop scrolling\n", + " if current_review_count >= max_reviews:\n", + " break\n", + "\n", + " print(f\"Scroll attempt {scroll_attempts + 1}: Found {current_review_count} reviews\")\n", + " \n", + " # Scroll to the bottom of 
the reviews panel\n", + " await scrollable_div.evaluate(\"\"\"\n", + " (element) => {\n", + " element.scrollTo(0, element.scrollHeight + 100);\n", + " }\n", + " \"\"\")\n", + " \n", + " # Wait for potential new content to load\n", + " await asyncio.sleep(2)\n", + " \n", + " # Check if new reviews were loaded\n", + " if current_review_count == previous_review_count:\n", + " no_change_count += 1\n", + " # If count hasn't changed for 3 consecutive scrolls, we've likely reached the end\n", + " if no_change_count >= 3:\n", + " print(f\"No new reviews loaded after {no_change_count} attempts. Finished loading.\")\n", + " break\n", + " else:\n", + " no_change_count = 0\n", + " \n", + " previous_review_count = current_review_count\n", + " scroll_attempts += 1\n", + " \n", + " final_count = await review_elements.count()\n", + " print(f\"Finished scrolling. Total reviews loaded: {final_count}\")\n", + " return final_count" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f7f67b70", + "metadata": {}, + "outputs": [], + "source": [ + "async def scrape_google_reviews(url):\n", + " # Where to store the scraped data\n", + " reviews = []\n", + "\n", + " async with async_playwright() as p:\n", + " # Initialize a new Playwright instance\n", + " browser = await p.chromium.launch(\n", + " headless=True # Set to False if you want to see the browser in action\n", + " )\n", + " context = await browser.new_context()\n", + " page = await context.new_page()\n", + "\n", + " # The URL of the Google Maps reviews page\n", + "\n", + " # Navigate to the target Google Maps page\n", + " print(\"Navigating to Google Maps page...\")\n", + " await page.goto(url)\n", + "\n", + " # Wait for initial reviews to load\n", + " print(\"Waiting for initial reviews to load...\")\n", + " review_html_elements = page.locator(\"div[data-review-id][jsaction]\")\n", + " await review_html_elements.first.wait_for(state=\"visible\", timeout=10000)\n", + " \n", + " # Scroll through the reviews panel to lazy load all reviews\n", + " total_reviews = await scroll_reviews_panel(page, max_scrolls=100)\n", + " \n", + " print(f\"\\nStarting to scrape {total_reviews} reviews...\")\n", + "\n", + " # Get all review elements after scrolling\n", + " review_html_elements = page.locator(\"div[data-review-id][jsaction]\")\n", + " all_reviews = await review_html_elements.all()\n", + " \n", + " # Iterate over the elements and scrape data from each of them\n", + " for idx, review_html_element in enumerate(all_reviews, 1):\n", + " try:\n", + " # Scraping logic\n", + "\n", + " stars_element = review_html_element.locator(\"[aria-label*=\\\"star\\\"]\")\n", + " stars_label = await stars_element.get_attribute(\"aria-label\")\n", + "\n", + " # Extract the review score from the stars label\n", + " stars = None\n", + " for i in range(1, 6):\n", + " if stars_label and str(i) in stars_label:\n", + " stars = i\n", + " break\n", + "\n", + " # Get the next sibling of the previous element with an XPath expression\n", + " time_sibling = stars_element.locator(\"xpath=following-sibling::span\")\n", + " time = await time_sibling.text_content()\n", + "\n", + " # Select the \"More\" button and if it is present, click it\n", + " more_element = review_html_element.locator(\"button[aria-label=\\\"See more\\\"]\").first\n", + " if await more_element.is_visible():\n", + " await more_element.click()\n", + " await asyncio.sleep(0.3) # Brief wait for text expansion\n", + "\n", + " text_element = review_html_element.locator(\"div[tabindex=\\\"-1\\\"][id][lang]\")\n", + " 
text = await text_element.text_content()\n", + "\n", + " reviews.append(str(stars) + \" Stars: \\n\" +\"Reviewed On:\" + time + \"\\n\"+ text)\n", + " \n", + " if idx % 10 == 0:\n", + " print(f\"Scraped {idx}/{total_reviews} reviews...\")\n", + " \n", + " except Exception as e:\n", + " print(f\"Error scraping review {idx}: {str(e)}\")\n", + " continue\n", + "\n", + " print(f\"\\nSuccessfully scraped {len(reviews)} reviews!\")\n", + "\n", + " # Close the browser and release its resources\n", + " await browser.close()\n", + "\n", + " return \"\\n\".join(reviews)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cb160d5f", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "You are an expert assistant that analyzes google reviews,\n", + "and provides a summary and centiment of the reviews.\n", + "Respond in markdown. Do not wrap the markdown in a code block - respond just with the markdown.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "69e08d4b", + "metadata": {}, + "outputs": [], + "source": [ + "# Define our user prompt\n", + "\n", + "user_prompt_prefix = \"\"\"\n", + "Here are the reviews of a google map location/business.\n", + "Provide a short summary of the reviews and the sentiment of the reviews.\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d710972d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def prepare_message(reviews):\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_prefix + reviews}\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cb51f436", + "metadata": {}, + "outputs": [], + "source": [ + "async def summarize(url):\n", + " openai = OpenAI()\n", + " reviews = await scrape_google_reviews(url)\n", + " response = openai.chat.completions.create(\n", + " model = \"gpt-4.1-mini\",\n", + " messages = prepare_message(reviews)\n", + " )\n", + " return response.choices[0].message.content" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2f09e2d2", + "metadata": {}, + "outputs": [], + "source": [ + "async def display_summary(url):\n", + " summary = await summarize(url)\n", + " display(Markdown(summary))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca7995c9", + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://www.google.com/maps/place/Grace+Home+Nursing+%26+Assisted+Living/@12.32184,75.0853037,17z/data=!4m8!3m7!1s0x3ba47da1be6a0279:0x9e73181ab0827f7e!8m2!3d12.32184!4d75.0853037!9m1!1b1!16s%2Fg%2F11qjl430n_?entry=ttu&g_ep=EgoyMDI1MTAyMC4wIKXMDSoASAFQAw%3D%3D\"\n", + "await display_summary(url)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summary.jpg b/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summary.jpg new file mode 100644 index 0000000..43a7891 Binary files /dev/null and 
b/week1/community-contributions/week1-google-map-review-summarizer/google-map-review-summary.jpg differ diff --git a/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb b/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb new file mode 100644 index 0000000..c0fbbcc --- /dev/null +++ b/week2/community-contributions/ai_domain_finder/ai_domain_finder.ipynb @@ -0,0 +1,721 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1633a440", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "Week 2 Assignment: LLM Engineering\n", + "Author: Nikhil Raut\n", + "\n", + "Notebook: ai_domain_finder.ipynb\n", + "\n", + "Purpose:\n", + "Build an agentic AI Domain Finder that proposes short, brandable .com names, verifies availability via RDAP, \n", + "then returns: \n", + " a list of available .coms, \n", + " one preferred pick, \n", + " and a brief audio rationale.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da528fbe", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import requests\n", + "from typing import Dict, List, Tuple, Any, Optional\n", + "import re\n", + "\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "OPENAI_MODEL = \"gpt-5-nano-2025-08-07\"\n", + "TTS_MODEL = \"gpt-4o-mini-tts\"\n", + "\n", + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "361f7fe3", + "metadata": {}, + "outputs": [], + "source": [ + "# --- robust logging that works inside VS Code notebooks + Gradio threads ---\n", + "import sys, logging, threading\n", + "from collections import deque\n", + "from typing import Any\n", + "\n", + "DEBUG_LLM = True # toggle on/off noisy logs\n", + "CLEAR_LOG_ON_RUN = True # clear panel before each submit\n", + "\n", + "_LOG_BUFFER = deque(maxlen=2000) # keep ~2000 lines in memory\n", + "_LOG_LOCK = threading.Lock()\n", + "\n", + "class GradioBufferHandler(logging.Handler):\n", + " def emit(self, record: logging.LogRecord) -> None:\n", + " try:\n", + " msg = self.format(record)\n", + " except Exception:\n", + " msg = record.getMessage()\n", + " with _LOG_LOCK:\n", + " for line in (msg.splitlines() or [\"\"]):\n", + " _LOG_BUFFER.append(line)\n", + "\n", + "def get_log_text() -> str:\n", + " with _LOG_LOCK:\n", + " return \"\\n\".join(_LOG_BUFFER)\n", + "\n", + "def clear_log_buffer() -> None:\n", + " with _LOG_LOCK:\n", + " _LOG_BUFFER.clear()\n", + "\n", + "def _setup_logger() -> logging.Logger:\n", + " logger = logging.getLogger(\"aidf\")\n", + " logger.setLevel(logging.DEBUG if DEBUG_LLM else logging.INFO)\n", + " logger.handlers.clear()\n", + " fmt = logging.Formatter(\"%(asctime)s | %(levelname)s | %(message)s\", \"%H:%M:%S\")\n", + "\n", + " stream = logging.StreamHandler(stream=sys.stdout) # captured by VS Code notebook\n", + " stream.setFormatter(fmt)\n", + "\n", + " buf = GradioBufferHandler() # shown inside the Gradio panel\n", + " buf.setFormatter(fmt)\n", + "\n", + " logger.addHandler(stream)\n", + " logger.addHandler(buf)\n", + " logger.propagate = False\n", + " return logger\n", + "\n", + "logger = _setup_logger()\n", + "\n", + "def dbg_json(obj: Any, title: str = \"\") -> None:\n", + " \"\"\"Convenience: pretty-print JSON-ish objects to the logger.\"\"\"\n", + " try:\n", + " txt = json.dumps(obj, ensure_ascii=False, indent=2)\n", + " except Exception:\n", + " txt = 
str(obj)\n", + " if title:\n", + " logger.debug(\"%s\\n%s\", title, txt)\n", + " else:\n", + " logger.debug(\"%s\", txt)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "519674b2", + "metadata": {}, + "outputs": [], + "source": [ + "RDAP_URL = \"https://rdap.verisign.com/com/v1/domain/{}\"\n", + "\n", + "_ALPHA_RE = re.compile(r\"^[a-z]+$\", re.IGNORECASE)\n", + "\n", + "def _to_com(domain: str) -> str:\n", + " d = domain.strip().lower()\n", + " return d if d.endswith(\".com\") else f\"{d}.com\"\n", + "\n", + "def _sld_is_english_alpha(fqdn: str) -> bool:\n", + " \"\"\"\n", + " True only if the second-level label (just before .com) is made up\n", + " exclusively of English letters (a-z).\n", + " Examples:\n", + " foo.com -> True\n", + " foo-bar.com -> False\n", + " foo1.com -> False\n", + " café.com -> False\n", + " xn--cafe.com -> False\n", + " www.foo.com -> True (checks 'foo')\n", + " \"\"\"\n", + " if not fqdn.endswith(\".com\"):\n", + " return False\n", + " sld = fqdn[:-4].split(\".\")[-1] # take label immediately before .com\n", + " return bool(sld) and bool(_ALPHA_RE.fullmatch(sld))\n", + "\n", + "def check_com_availability(domain: str) -> Dict:\n", + " fqdn = _to_com(domain)\n", + " # Skip API if not strictly English letters\n", + " if not _sld_is_english_alpha(fqdn):\n", + " return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n", + "\n", + " try:\n", + " r = requests.get(RDAP_URL.format(fqdn), timeout=6)\n", + " return {\"domain\": fqdn, \"available\": (r.status_code == 404), \"status\": r.status_code}\n", + " except requests.RequestException:\n", + " return {\"domain\": fqdn, \"available\": False, \"status\": 0}\n", + "\n", + "def check_com_availability_bulk(domains: List[str]) -> Dict:\n", + " \"\"\"\n", + " Input: list of domain roots or FQDNs.\n", + " Returns:\n", + " {\n", + " \"results\": [{\"domain\": \"...\", \"available\": bool, \"status\": int}, ...],\n", + " \"available\": [\"...\"], # convenience\n", + " \"count_available\": int\n", + " }\n", + " \"\"\"\n", + " session = requests.Session()\n", + " results: List[Dict] = []\n", + "\n", + " for d in domains:\n", + " fqdn = _to_com(d)\n", + "\n", + " # Skip API if not strictly English letters\n", + " if not _sld_is_english_alpha(fqdn):\n", + " results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n", + " continue\n", + "\n", + " try:\n", + " r = session.get(RDAP_URL.format(fqdn), timeout=6)\n", + " ok = (r.status_code == 404)\n", + " results.append({\"domain\": fqdn, \"available\": ok, \"status\": r.status_code})\n", + " except requests.RequestException:\n", + " results.append({\"domain\": fqdn, \"available\": False, \"status\": 0})\n", + "\n", + " available = [x[\"domain\"] for x in results if x[\"available\"]]\n", + " return {\"results\": results, \"available\": available, \"count_available\": len(available)}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd20c262", + "metadata": {}, + "outputs": [], + "source": [ + "check_tool_bulk = {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"check_com_availability_bulk\",\n", + " \"description\": \"Batch check .com availability via RDAP for a list of domains (roots or FQDNs).\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"domains\": {\n", + " \"type\": \"array\",\n", + " \"items\": {\"type\": \"string\"},\n", + " \"minItems\": 1,\n", + " \"maxItems\": 50,\n", + " \"description\": \"List of domain roots or .com FQDNs.\"\n", 
+ " }\n", + " },\n", + " \"required\": [\"domains\"],\n", + " \"additionalProperties\": False\n", + " }\n", + " }\n", + "}\n", + "\n", + "TOOLS = [check_tool_bulk]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a9138b6", + "metadata": {}, + "outputs": [], + "source": [ + "def handle_tool_calls(message) -> List[Dict]:\n", + " results = []\n", + " for call in (message.tool_calls or []):\n", + " fn = getattr(call.function, \"name\", None)\n", + " args_raw = getattr(call.function, \"arguments\", \"\") or \"{}\"\n", + " try:\n", + " args = json.loads(args_raw)\n", + " except Exception:\n", + " args = {}\n", + "\n", + " logger.debug(\"TOOL CALL -> %s | args=%s\", fn, json.dumps(args, ensure_ascii=False))\n", + "\n", + " if fn == \"check_com_availability_bulk\":\n", + " payload = check_com_availability_bulk(args.get(\"domains\", []))\n", + " elif fn == \"check_com_availability\":\n", + " payload = check_com_availability(args.get(\"domain\", \"\"))\n", + " else:\n", + " payload = {\"error\": f\"unknown tool {fn}\"}\n", + "\n", + " logger.debug(\"TOOL RESULT <- %s | %s\", fn, json.dumps(payload, ensure_ascii=False))\n", + "\n", + " results.append({\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": call.id,\n", + " \"content\": json.dumps(payload),\n", + " })\n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b80c860", + "metadata": {}, + "outputs": [], + "source": [ + "SYSTEM_PROMPT = \"\"\"You are the Agent for project \"AI Domain Finder\".\n", + "Goal: suggest .com domains and verify availability using the tool ONLY (no guessing).\n", + "\n", + "Do this each interaction:\n", + "- Generate up to ~20 short, brandable .com candidates from:\n", + " (1) Industry, (2) Target Customers, (3) Description.\n", + "- Use the BULK tool `check_com_availability_bulk` with a list of candidates\n", + " (roots or FQDNs). 
Prefer a single call or very few batched calls.\n", + "- If >= 5 available .coms are found, STOP checking and finalize the answer.\n", + "\n", + "Output Markdown with EXACT section headings:\n", + "1) Available .com domains:\n", + " - itemized list of available .coms only (root + .com)\n", + "2) Preferred domain:\n", + " - a single best pick\n", + "3) Audio explanation:\n", + " - 1–2 concise sentences explaining the preference\n", + "\n", + "Constraints:\n", + "- Use customer-familiar words where helpful.\n", + "- Keep names short, simple, pronounceable; avoid hyphens/numbers unless meaningful.\n", + "- Never include TLDs other than .com.\n", + "- Domain names must consist of lowercase English letters only - no symbols, digits, or spaces.\n", + "\"\"\"\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "72e9d8c2", + "metadata": {}, + "outputs": [], + "source": [ + "def _asdict_tool_call(tc: Any) -> dict:\n", + " try:\n", + " return {\n", + " \"id\": getattr(tc, \"id\", None),\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": getattr(tc.function, \"name\", None),\n", + " \"arguments\": getattr(tc.function, \"arguments\", None),\n", + " },\n", + " }\n", + " except Exception:\n", + " return {\"type\": \"function\", \"function\": {\"name\": None, \"arguments\": None}}\n", + "\n", + "def _asdict_message(msg: Any) -> dict:\n", + " if isinstance(msg, dict):\n", + " return msg\n", + " role = getattr(msg, \"role\", None)\n", + " content = getattr(msg, \"content\", None)\n", + " tool_calls = getattr(msg, \"tool_calls\", None)\n", + " out = {\"role\": role, \"content\": content}\n", + " if tool_calls:\n", + " out[\"tool_calls\"] = [_asdict_tool_call(tc) for tc in tool_calls]\n", + " return out\n", + "\n", + "def _sanitized_messages_for_log(messages: list[dict | Any]) -> list[dict]:\n", + " return [_asdict_message(m) for m in messages]\n", + "\n", + "def _limit_text(s: str, limit: int = 40000) -> str:\n", + " return s if len(s) <= limit else (s[:limit] + \"\\n... 
[truncated]\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b45c6382", + "metadata": {}, + "outputs": [], + "source": [ + "def run_agent_with_tools(history: List[Dict]) -> Tuple[str, List[str], str]:\n", + " \"\"\"\n", + " Returns:\n", + " reply_md: final assistant markdown\n", + " tool_available: .coms marked available by RDAP tools (order-preserving, deduped)\n", + " dbg_text: concatenated log buffer (for the UI panel)\n", + " \"\"\"\n", + " messages: List[Dict] = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] + history\n", + " tool_available: List[str] = []\n", + "\n", + " dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (initial messages) ===\")\n", + " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", + "\n", + " while resp.choices[0].finish_reason == \"tool_calls\":\n", + " tool_msg_sdk = resp.choices[0].message\n", + " tool_msg = _asdict_message(tool_msg_sdk)\n", + " dbg_json(tool_msg, \"=== ASSISTANT (tool_calls) ===\")\n", + "\n", + " tool_results = handle_tool_calls(tool_msg_sdk)\n", + "\n", + " # Accumulate authoritative availability directly from tool outputs\n", + " for tr in tool_results:\n", + " try:\n", + " data = json.loads(tr[\"content\"])\n", + " if isinstance(data, dict) and isinstance(data.get(\"available\"), list):\n", + " for d in data[\"available\"]:\n", + " tool_available.append(_to_com(d))\n", + " except Exception:\n", + " pass\n", + "\n", + " dbg_json([json.loads(tr[\"content\"]) for tr in tool_results], \"=== TOOL RESULTS ===\")\n", + "\n", + " messages.append(tool_msg)\n", + " messages.extend(tool_results)\n", + " dbg_json(_sanitized_messages_for_log(messages), \"=== LLM REQUEST (messages + tools) ===\")\n", + "\n", + " resp = openai.chat.completions.create(model=OPENAI_MODEL, messages=messages, tools=TOOLS)\n", + "\n", + " # Dedup preserve order\n", + " seen, uniq = set(), []\n", + " for d in tool_available:\n", + " if d not in seen:\n", + " seen.add(d)\n", + " uniq.append(d)\n", + "\n", + " reply_md = resp.choices[0].message.content\n", + " logger.debug(\"=== FINAL ASSISTANT ===\\n%s\", _limit_text(reply_md))\n", + " dbg_json(uniq, \"=== AVAILABLE FROM TOOLS (authoritative) ===\")\n", + "\n", + " # Return current buffer text for the UI panel\n", + " dbg_text = _limit_text(get_log_text(), 40000)\n", + " return reply_md, uniq, dbg_text\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92306515", + "metadata": {}, + "outputs": [], + "source": [ + "def extract_audio_text(markdown_reply: str) -> str:\n", + " \"\"\"\n", + " Pulls the 'Audio explanation:' section; falls back to first sentence.\n", + " \"\"\"\n", + " marker = \"Audio explanation:\"\n", + " lower = markdown_reply.lower()\n", + " idx = lower.find(marker.lower())\n", + " if idx != -1:\n", + " segment = markdown_reply[idx + len(marker):].strip()\n", + " parts = segment.split(\".\")\n", + " return (\". 
\".join([p.strip() for p in parts if p.strip()][:2]) + \".\").strip()\n", + " return \"This domain is the clearest, most memorable fit for the audience and brand goals.\"\n", + "\n", + "def synth_audio(text: str) -> bytes:\n", + " audio = openai.audio.speech.create(\n", + " model=TTS_MODEL,\n", + " voice=\"alloy\",\n", + " input=text\n", + " )\n", + " return audio.content\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc6c0650", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "_DOMAIN_RE = re.compile(r\"\\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\\.com\\b\", re.I)\n", + "_HDR_AVAIL = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*available\\s+.*\\.com\\s+domains\", re.I)\n", + "_HDR_PREF = re.compile(r\"^\\s*[\\d\\.\\)\\-]*\\s*preferred\\s+domain\", re.I)\n", + "\n", + "def _norm_domain(s: str) -> str:\n", + " s = s.strip().lower()\n", + " return s if s.endswith(\".com\") else f\"{s}.com\"\n", + "\n", + "def parse_available(md: str) -> list[str]:\n", + " lines = md.splitlines()\n", + " out = []\n", + " in_section = False\n", + " for ln in lines:\n", + " if _HDR_AVAIL.search(ln):\n", + " in_section = True\n", + " continue\n", + " if in_section and _HDR_PREF.search(ln):\n", + " break\n", + " if in_section:\n", + " for m in _DOMAIN_RE.findall(ln):\n", + " out.append(_norm_domain(m))\n", + " # Fallback: if the header wasn't found, collect all .coms then we'll still\n", + " # rely on agent instruction to list only available, which should be safe.\n", + " if not out:\n", + " out = [_norm_domain(m) for m in _DOMAIN_RE.findall(md)]\n", + " # dedupe preserve order\n", + " seen, uniq = set(), []\n", + " for d in out:\n", + " if d not in seen:\n", + " seen.add(d)\n", + " uniq.append(d)\n", + " return uniq\n", + "\n", + "def parse_preferred(md: str) -> str:\n", + " # search the preferred section first\n", + " lines = md.splitlines()\n", + " start = None\n", + " for i, ln in enumerate(lines):\n", + " if _HDR_PREF.search(ln):\n", + " start = i\n", + " break\n", + " segment = \"\\n\".join(lines[start:start+8]) if start is not None else md[:500]\n", + " m = _DOMAIN_RE.search(segment)\n", + " if m:\n", + " return _norm_domain(m.group(0))\n", + " m = _DOMAIN_RE.search(md)\n", + " return _norm_domain(m.group(0)) if m else \"\"\n", + "\n", + "def merge_and_sort(old: list[str], new: list[str]) -> list[str]:\n", + " merged = {d.lower() for d in old} | {d.lower() for d in new}\n", + " return sorted(merged, key=lambda s: (len(s), s))\n", + "\n", + "def fmt_available_md(domains: list[str]) -> str:\n", + " if not domains:\n", + " return \"### Available .com domains (cumulative)\\n\\n*– none yet –*\"\n", + " items = \"\\n\".join(f\"- `{d}`\" for d in domains)\n", + " return f\"### Available .com domains (cumulative)\\n\\n{items}\"\n", + "\n", + "def fmt_preferred_md(d: str) -> str:\n", + " if not d:\n", + " return \"### Preferred domain\\n\\n*– not chosen yet –*\"\n", + " return f\"### Preferred domain\\n\\n`{d}`\"\n", + "\n", + "def build_context_msg(known_avail: Optional[List[str]], preferred_now: Optional[str]) -> str:\n", + " \"\"\"\n", + " Create a short 'state so far' block that we prepend to the next user turn\n", + " so the model always sees the preferred and cumulative available list.\n", + " \"\"\"\n", + " lines = []\n", + " if (preferred_now or \"\").strip():\n", + " lines.append(f\"Preferred domain so far: {preferred_now.strip().lower()}\")\n", + " if known_avail:\n", + " lines.append(\"Available .com domains discovered so far:\")\n", + " for d in known_avail:\n", + " if d:\n", 
+ " lines.append(f\"- {d.strip().lower()}\")\n", + " if not lines:\n", + " return \"\"\n", + " return \"STATE TO CARRY OVER FROM PREVIOUS TURNS:\\n\" + \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07f079d6", + "metadata": {}, + "outputs": [], + "source": [ + "def run_and_extract(history: List[Dict]) -> Tuple[str, List[str], str, str, str]:\n", + " reply_md, avail_from_tools, dbg_text = run_agent_with_tools(history)\n", + " parsed_avail = parse_available(reply_md)\n", + " new_avail = merge_and_sort(avail_from_tools, parsed_avail)\n", + " preferred = parse_preferred(reply_md)\n", + " audio_text = extract_audio_text(reply_md)\n", + " return reply_md, new_avail, preferred, audio_text, dbg_text\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cd5d8ef", + "metadata": {}, + "outputs": [], + "source": [ + "def initial_submit(industry: str, customers: str, desc: str,\n", + " history: List[Dict], known_avail: List[str], preferred_now: str):\n", + " if CLEAR_LOG_ON_RUN:\n", + " clear_log_buffer()\n", + "\n", + " logger.info(\"Initial submit | industry=%r | customers=%r | desc_len=%d\",\n", + " industry, customers, len(desc or \"\"))\n", + "\n", + " # Build context (usually empty on the very first run, but future inits also work)\n", + " ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n", + "\n", + " user_msg = (\n", + " \"Please propose .com domains based on:\\n\"\n", + " f\"Industry: {industry}\\n\"\n", + " f\"Target Customers: {customers}\\n\"\n", + " f\"Description: {desc}\"\n", + " )\n", + "\n", + " # Single user turn that includes state + prompt so the model always sees memory\n", + " full_content = (ctx + \"\\n\\n\" if ctx else \"\") + user_msg\n", + "\n", + " history = (history or []) + [{\"role\": \"user\", \"content\": full_content}]\n", + " reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n", + " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", + "\n", + " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n", + " preferred_final = preferred or preferred_now or \"\"\n", + " audio_bytes = synth_audio(audio_text)\n", + "\n", + " return (\n", + " history, # s_history\n", + " all_avail, # s_available (cumulative)\n", + " preferred_final, # s_preferred\n", + " gr.update(value=fmt_preferred_md(preferred_final)),\n", + " gr.update(value=fmt_available_md(all_avail)),\n", + " gr.update(value=\"\", visible=True), # reply_in: show after first run\n", + " gr.update(value=audio_bytes, visible=True), # audio_out\n", + " gr.update(value=dbg_text), # debug_box\n", + " gr.update(value=\"Find Domains (done)\", interactive=False), # NEW: disable Find\n", + " gr.update(visible=True), # NEW: show Send button\n", + " )\n", + "\n", + "def refine_submit(reply: str,\n", + " history: List[Dict], known_avail: List[str], preferred_now: str):\n", + " # If empty, do nothing (keeps UI state untouched)\n", + " if not (reply or \"\").strip():\n", + " return (\"\", history, known_avail, preferred_now,\n", + " gr.update(), gr.update(), gr.update(), gr.update())\n", + "\n", + " if CLEAR_LOG_ON_RUN:\n", + " clear_log_buffer()\n", + " logger.info(\"Refine submit | user_reply_len=%d\", len(reply))\n", + "\n", + " # Always prepend memory + the user's refinement so the model can iterate properly\n", + " ctx = build_context_msg(known_avail or [], preferred_now or \"\")\n", + " full_content = (ctx + \"\\n\\n\" if ctx else \"\") + reply.strip()\n", + "\n", + " history = (history or 
[]) + [{\"role\": \"user\", \"content\": full_content}]\n", + " reply_md, new_avail, preferred, audio_text, dbg_text = run_and_extract(history)\n", + " history += [{\"role\": \"assistant\", \"content\": reply_md}]\n", + "\n", + " all_avail = merge_and_sort(known_avail or [], new_avail or [])\n", + " preferred_final = preferred or preferred_now or \"\"\n", + " audio_bytes = synth_audio(audio_text)\n", + "\n", + " return (\n", + " \"\", # clear Reply box\n", + " history, # s_history\n", + " all_avail, # s_available (cumulative)\n", + " preferred_final, # s_preferred\n", + " gr.update(value=fmt_preferred_md(preferred_final)),\n", + " gr.update(value=fmt_available_md(all_avail)),\n", + " gr.update(value=audio_bytes, visible=True),\n", + " gr.update(value=dbg_text), # debug_box\n", + " )\n", + "\n", + "def clear_debug():\n", + " clear_log_buffer()\n", + " return gr.update(value=\"\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d52ebc02", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(title=\"AI Domain Finder (.com only)\") as ui:\n", + " gr.Markdown(\"# AI Domain Finder (.com only)\")\n", + " gr.Markdown(\"Agent proposes .com domains, verifies via RDAP, picks a preferred choice, and explains briefly.\")\n", + "\n", + " # App state\n", + " s_history = gr.State([])\n", + " s_available = gr.State([])\n", + " s_preferred = gr.State(\"\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=7): # LEFT 70%\n", + " with gr.Group():\n", + " industry_in = gr.Textbox(label=\"Industry\")\n", + " customers_in = gr.Textbox(label=\"Target Customers\")\n", + " desc_in = gr.Textbox(label=\"Description\", lines=3)\n", + " find_btn = gr.Button(\"Find Domains\", variant=\"primary\")\n", + "\n", + " audio_out = gr.Audio(label=\"Audio explanation\", autoplay=True, visible=False)\n", + "\n", + " with gr.Row():\n", + " reply_in = gr.Textbox(\n", + " label=\"Reply\",\n", + " placeholder=\"Chat with the agent to refine the outputs\",\n", + " lines=2,\n", + " visible=False, # hidden for the first input\n", + " )\n", + " send_btn = gr.Button(\"Send\", variant=\"primary\", visible=False)\n", + "\n", + " with gr.Column(scale=3): # RIGHT 30%\n", + " preferred_md = gr.Markdown(fmt_preferred_md(\"\"))\n", + " available_md = gr.Markdown(fmt_available_md([]))\n", + "\n", + " with gr.Accordion(\"Debug log\", open=False):\n", + " debug_box = gr.Textbox(label=\"Log\", value=\"\", lines=16, interactive=False)\n", + " clear_btn = gr.Button(\"Clear log\", size=\"sm\")\n", + "\n", + " # Events\n", + " # Initial run: also disables Find and shows Send\n", + " find_btn.click(\n", + " initial_submit,\n", + " inputs=[industry_in, customers_in, desc_in, s_history, s_available, s_preferred],\n", + " outputs=[\n", + " s_history, s_available, s_preferred,\n", + " preferred_md, available_md,\n", + " reply_in, # visible after first run\n", + " audio_out, # visible after first run\n", + " debug_box,\n", + " find_btn, # NEW: disable + relabel\n", + " send_btn, # NEW: show the Send button\n", + " ],\n", + " )\n", + "\n", + " # Multi-turn submit via Enter in the textbox\n", + " reply_in.submit(\n", + " refine_submit,\n", + " inputs=[reply_in, s_history, s_available, s_preferred],\n", + " outputs=[\n", + " reply_in, s_history, s_available, s_preferred,\n", + " preferred_md, available_md, audio_out, debug_box\n", + " ],\n", + " )\n", + "\n", + " # Multi-turn submit via explicit Send button\n", + " send_btn.click(\n", + " refine_submit,\n", + " inputs=[reply_in, s_history, s_available, 
s_preferred],\n", + " outputs=[\n", + " reply_in, s_history, s_available, s_preferred,\n", + " preferred_md, available_md, audio_out, debug_box\n", + " ],\n", + " )\n", + "\n", + " clear_btn.click(clear_debug, inputs=[], outputs=[debug_box])\n", + "\n", + "ui.launch(inbrowser=True, show_error=True)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm-engineering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/hopeogbons/README.md b/week2/community-contributions/hopeogbons/README.md new file mode 100644 index 0000000..953a7e5 --- /dev/null +++ b/week2/community-contributions/hopeogbons/README.md @@ -0,0 +1,355 @@ +# 🏥 RoboCare AI Assistant + +> Born from a real problem at MyWoosah Inc—now solving caregiver matching through AI. + +## 📋 The Story Behind This Project + +While working on a caregiver matching platform for **MyWoosah Inc** in the US, I faced a real challenge: how do you efficiently match families with the right caregivers when everyone has different needs? + +Families would ask things like: + +- _"I need someone for my mom on Monday mornings who speaks Spanish"_ +- _"Can you find elder care in Boston under $30/hour with CPR certification?"_ + +Writing individual SQL queries for every combination of filters was exhausting and error-prone. I knew there had to be a better way. + +That's when I discovered the **Andela LLM Engineering program**. I saw an opportunity to transform this problem into a solution using AI. Instead of rigid queries, what if families could just... talk? And the AI would understand, search, and recommend? + +This project is my answer. It's not just an exercise—it's solving a real problem I encountered in the field. + +## What It Does + +RoboCare helps families find caregivers through natural conversation: + +- 🔍 Searches the database intelligently +- 🎯 Finds the best matches +- 💬 Explains pros/cons in plain English +- 🔊 Speaks the results back to you + +## ✨ Features + +### 🤖 AI-Powered Matching + +- Natural language conversation interface +- Intelligent requirement gathering +- Multi-criteria search optimization +- Personalized recommendations with pros/cons analysis + +### 🔍 Advanced Search Capabilities + +- **Location-based filtering**: City, state, and country +- **Service type matching**: Elder care, child care, companionship, dementia care, hospice support, and more +- **Availability scheduling**: Day and time-based matching +- **Budget optimization**: Maximum hourly rate filtering +- **Language preferences**: Multi-language support +- **Certification requirements**: CPR, CNA, BLS, and specialized certifications +- **Experience filtering**: Minimum years of experience + +### 🎙️ Multi-Modal Interface + +- Text-based chat interface +- Voice response generation (Text-to-Speech) +- Multiple voice options (coral, alloy, echo, fable, onyx, nova, shimmer) +- Clean, modern UI built with Gradio + +### 🛡️ Defensive Architecture + +- Comprehensive error handling +- Token overflow protection +- Tool call validation +- Graceful degradation + +## 🚀 Getting Started + +### Prerequisites + +- Python 3.8+ +- OpenAI API key +- Virtual environment (recommended) + +### Installation + +1. 
**Clone the repository** + + ```bash + cd week2 + ``` + +2. **Create and activate virtual environment** + + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +3. **Install dependencies** + + ```bash + pip install -r requirements.txt + ``` + +4. **Set up environment variables** + + Create a `.env` file in the project root: + + ```env + OPENAI_API_KEY=your_openai_api_key_here + ``` + +5. **Run the application** + + ```bash + jupyter notebook "week2 EXERCISE.ipynb" + ``` + + Or run all cells sequentially in your Jupyter environment. + +## 📊 Database Schema + +### Tables + +#### `caregivers` + +Primary caregiver information including: + +- Personal details (name, gender) +- Experience level +- Hourly rate and currency +- Location (city, state, country, coordinates) +- Live-in availability + +#### `caregiver_services` + +Care types offered by each caregiver: + +- Elder care +- Child care +- Companionship +- Post-op support +- Special needs +- Respite care +- Dementia care +- Hospice support + +#### `availability` + +Time slots when caregivers are available: + +- Day of week (Mon-Sun) +- Start and end times (24-hour format) + +#### `languages` + +Languages spoken by caregivers + +#### `certifications` + +Professional certifications (CPR, CNA, BLS, etc.) + +#### `traits` + +Personality and professional traits + +## 🔧 Architecture + +### Tool Registry Pattern + +```python +TOOL_REGISTRY = { + "search_caregivers": search_caregivers, + "get_caregiver_profile": get_caregiver_profile, + # ... more tools +} +``` + +All database functions are registered and can be called by the AI dynamically. + +### Search Functions + +#### `search_caregivers()` + +Multi-filter search with parameters: + +- `city`, `state_province`, `country` - Location filters +- `care_type` - Type of care needed +- `min_experience` - Minimum years of experience +- `max_hourly_rate` - Budget constraint +- `live_in` - Live-in caregiver requirement +- `language` - Language preference +- `certification` - Required certification +- `day` - Day of week availability +- `time_between` - Time window availability +- `limit`, `offset` - Pagination + +#### `get_caregiver_profile(caregiver_id)` + +Returns complete profile including: + +- Basic information +- Services offered +- Languages spoken +- Certifications +- Personality traits +- Availability schedule + +## 🎨 UI Components + +### Main Interface + +- **Chat History**: Message-based conversation display +- **Voice Response**: Auto-playing audio output +- **Settings Sidebar**: + - AI Model selector + - Voice selection + - Audio toggle + - Clear conversation button + +### User Experience + +- Professional gradient header +- Collapsible instructions +- Helpful placeholder text +- Custom CSS styling +- Responsive layout + +## 📝 Usage Examples + +### Example 1: Basic Search + +```python +results = search_caregivers( + city="New York", + care_type="elder care", + max_hourly_rate=30.0, + limit=5 +) +``` + +### Example 2: Language Filter + +```python +results = search_caregivers( + care_type="child care", + language="Spanish", + limit=3 +) +``` + +### Example 3: Availability Search + +```python +results = search_caregivers( + day="Mon", + time_between=("09:00", "17:00"), + city="Boston" +) +``` + +### Example 4: Get Full Profile + +```python +profile = get_caregiver_profile(caregiver_id=1) +print(profile['services']) +print(profile['availability']) +``` + +## 🔐 Security & Best Practices + +### Current Implementation + +- ✅ Environment variable 
management for API keys +- ✅ SQL injection prevention (parameterized queries) +- ✅ Error handling and graceful degradation +- ✅ Input validation through tool schemas + +### Important Disclaimers + +⚠️ **This is a demonstration application** + +- Credentials and background checks are NOT verified +- Families should independently verify all caregiver information +- Not intended for production use without additional security measures + +## 🛠️ Tech Stack + +- **AI/ML**: OpenAI GPT-4o-mini, Text-to-Speech API +- **Database**: SQLite with normalized schema +- **UI Framework**: Gradio +- **Language**: Python 3.8+ +- **Key Libraries**: + - `openai` - OpenAI API client + - `gradio` - Web interface + - `python-dotenv` - Environment management + - `sqlite3` - Database operations + +## 📈 What's Next + +### Immediate Plans + +- [ ] Add speech input (families could call and talk) +- [ ] Connect to actual MyWoosah database +- [ ] Background check API integration +- [ ] Deploy for real users + +### Future Enhancements + +- [ ] Streaming responses for real-time interaction +- [ ] Dynamic model switching +- [ ] User authentication and profiles +- [ ] Review and rating system +- [ ] Payment integration +- [ ] Calendar integration for scheduling + +## 💡 Key Learnings + +Through building this project, I learned: + +1. **Prompt engineering is critical** - Small keyword mismatches = zero results. Mapping "Monday" → "Mon" matters. +2. **Function calling is powerful** - Eliminated the need for custom queries. The AI figures it out. +3. **Defensive programming saves headaches** - Things break. This code expects it and handles it elegantly. +4. **AI makes databases accessible** - Good database design + AI = natural language interface + +## 🌍 The Bigger Picture + +This isn't just about caregiving. The same pattern works for: + +- Healthcare appointment booking +- Legal service matching +- Tutoring and education platforms +- Real estate agent matching +- Any matching problem where natural language beats forms + +**AI doesn't replace good database design—it makes it accessible to everyone.** + +--- + +## 🤝 Contributing + +This project was created as part of the **Andela LLM Engineering Week 2 Exercise**. + +Feedback and contributions are welcome! Feel free to: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Run all cells to test +5. Submit a pull request + +## 🙏 Acknowledgments + +- **MyWoosah Inc** - For the real-world problem that inspired this solution +- **Andela LLM Engineering Program** - Educational framework and guidance +- **OpenAI** - GPT-4o and TTS API +- **Gradio** - Making beautiful UIs accessible + +--- + +
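*A tiny, hypothetical sketch of key learning #1 above — normalizing free-form phrasing to the exact keywords the database expects. The mapping tables are illustrative assumptions drawn from this project's seed data; `normalize_filters` is not part of the notebook:*

```python
# Hypothetical helper: rewrite natural-language values to the exact
# vocabulary used by the seed data ("Monday" -> "Mon", "elderly care" -> "elder care").
DAY_MAP = {
    "monday": "Mon", "tuesday": "Tue", "wednesday": "Wed", "thursday": "Thu",
    "friday": "Fri", "saturday": "Sat", "sunday": "Sun",
}
CARE_TYPE_MAP = {
    "elderly care": "elder care", "senior care": "elder care",
    "babysitting": "child care", "nanny": "child care",
    "memory care": "dementia care",
}

def normalize_filters(day=None, care_type=None):
    """Return (day, care_type) rewritten to match the database vocabulary."""
    if day:
        day = DAY_MAP.get(day.strip().lower(), day)
    if care_type:
        care_type = CARE_TYPE_MAP.get(care_type.strip().lower(), care_type)
    return day, care_type

# normalize_filters("Monday", "elderly care") -> ("Mon", "elder care")
```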
+ +**For MyWoosah Inc and beyond:** This is proof that AI can transform how we connect people with the care they need. + +_Built with ❤️ during Week 2 of the Andela LLM Engineering Program_ + +**RoboOffice Ltd** + +
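*One more sketch for completeness: the Tool Registry Pattern described above needs a dispatch step to route the model's tool calls to Python functions. The notebook's `execute_tool_call()` does this with fuller error handling; below is a minimal version. The `tool_call` attribute shapes (`id`, `function.name`, `function.arguments`) follow the OpenAI Python SDK:*

```python
import json

def dispatch_tool_call(tool_call, registry):
    """Route a single model-issued tool call to the matching registered function."""
    func = registry.get(tool_call.function.name)
    if func is None:
        content = json.dumps({"error": f"Unknown function: {tool_call.function.name}"})
    else:
        args = json.loads(tool_call.function.arguments or "{}")
        content = json.dumps(func(**args))
    # The result is fed back to the model as a "tool" role message.
    return {"role": "tool", "tool_call_id": tool_call.id, "content": content}

# Usage sketch: dispatch_tool_call(response.choices[0].message.tool_calls[0], TOOL_REGISTRY)
```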
diff --git a/week2/community-contributions/hopeogbons/care_app.db b/week2/community-contributions/hopeogbons/care_app.db new file mode 100644 index 0000000..93f8fdb Binary files /dev/null and b/week2/community-contributions/hopeogbons/care_app.db differ diff --git a/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb b/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb new file mode 100644 index 0000000..6915f24 --- /dev/null +++ b/week2/community-contributions/hopeogbons/week2 EXERCISE.ipynb @@ -0,0 +1,1525 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd", + "metadata": {}, + "source": [ + "# 🏥 RoboCare AI Assistant\n", + "\n", + "## Why I Built This\n", + "\n", + "While working on a caregiver matching platform for **MyWoosah Inc** in the US, I faced a real challenge: how do you efficiently match families with the right caregivers when everyone has different needs?\n", + "\n", + "Families would ask things like:\n", + "- *\"I need someone for my mom on Monday mornings who speaks Spanish\"*\n", + "- *\"Can you find elder care in Boston under $30/hour with CPR certification?\"*\n", + "\n", + "Writing individual SQL queries for every combination of filters was exhausting and error-prone. I knew there had to be a better way.\n", + "\n", + "That's when I discovered the **Andela LLM Engineering program**. I saw an opportunity to transform this problem into a solution using AI. Instead of rigid queries, what if families could just... talk? And the AI would understand, search, and recommend?\n", + "\n", + "This project is my answer. It's not just an exercise—it's solving a real problem I encountered in the field.\n", + "\n", + "---\n", + "\n", + "## What This Does\n", + "\n", + "RoboCare helps families find caregivers through natural conversation. You tell it what you need, and it:\n", + "- 🔍 Searches the database intelligently\n", + "- 🎯 Finds the best matches\n", + "- 💬 Explains pros/cons in plain English \n", + "- 🔊 Speaks the results back to you\n", + "\n", + "**Tech:** OpenAI GPT-4o + Voice • Gradio UI • SQLite Database • Function Calling\n", + "\n", + "---\n", + "\n", + "**Note:** This is a demonstration. Always verify credentials independently." 
+ ] + }, + { + "cell_type": "markdown", + "id": "4381c40c", + "metadata": {}, + "source": [ + "## Step 1: Libraries\n", + "\n", + "The essentials: OpenAI for the AI brain, Gradio for the interface, SQLite for data storage.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "185c6841", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "import sqlite3\n", + "from textwrap import dedent\n", + "from contextlib import contextmanager\n", + "from typing import Optional, List, Dict, Any, Tuple" + ] + }, + { + "cell_type": "markdown", + "id": "2a366c15", + "metadata": {}, + "source": [ + "## Step 2: Setup\n", + "\n", + "Loading API keys securely (never hardcode them!), setting up the OpenAI client, and pointing to our database.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "0e731b96", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n" + ] + } + ], + "source": [ + "# Initialization\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "if openai_api_key:\n", + "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + "    print(\"OpenAI API Key not set\")\n", + "\n", + "MODEL = \"gpt-4o-mini\"\n", + "openai = OpenAI()\n", + "\n", + "DB_PATH = \"care_app.db\"" + ] + }, + { + "cell_type": "markdown", + "id": "686fa27a", + "metadata": {}, + "source": [ + "## Step 3: The Database\n", + "\n", + "20 sample caregivers across major US cities with:\n", + "- Services they offer (elder care, child care, etc.)\n", + "- Languages, certifications, availability\n", + "- Personality traits\n", + "- Realistic pricing and schedules\n", + "\n", + "This mirrors the kind of data MyWoosah Inc would manage—except here, AI does the matching work.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "965d273d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Seeded: care_app.db\n" + ] + } + ], + "source": [ + "# Table creation and seeding\n", + "\n", + "SQL = '''\n", + "\n", + "    CREATE TABLE IF NOT EXISTS caregivers (\n", + "        id INTEGER PRIMARY KEY,\n", + "        name TEXT NOT NULL,\n", + "        gender TEXT,\n", + "        years_experience INTEGER,\n", + "        live_in INTEGER, -- 0/1\n", + "        hourly_rate REAL,\n", + "        currency TEXT,\n", + "        city TEXT,\n", + "        state_province TEXT,\n", + "        country TEXT,\n", + "        postal_code TEXT,\n", + "        lat REAL,\n", + "        lon REAL\n", + "    );\n", + "\n", + "    CREATE TABLE IF NOT EXISTS caregiver_services (\n", + "        caregiver_id INTEGER,\n", + "        care_type TEXT,\n", + "        FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + "    );\n", + "\n", + "    CREATE TABLE IF NOT EXISTS availability (\n", + "        caregiver_id INTEGER,\n", + "        day TEXT, -- e.g., 'Mon'\n", + "        time_start TEXT, -- 'HH:MM'\n", + "        time_end TEXT, -- 'HH:MM'\n", + "        FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + "    );\n", + "\n", + "    CREATE TABLE IF NOT EXISTS languages (\n", + "        caregiver_id INTEGER,\n", + "        language TEXT,\n", + "        FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + "    );\n", + "\n", + "    CREATE TABLE IF NOT EXISTS certifications (\n", + "        caregiver_id INTEGER,\n", + "        cert TEXT,\n", + "        FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + "    );\n", + "\n", + "    CREATE TABLE IF
NOT EXISTS traits (\n", + " caregiver_id INTEGER,\n", + " trait TEXT,\n", + " FOREIGN KEY (caregiver_id) REFERENCES caregivers(id)\n", + " );\n", + "\n", + " ----------------------------------------------------------\n", + "\n", + " -- Clear old data (optional)\n", + "\n", + " DELETE FROM traits;\n", + " DELETE FROM certifications;\n", + " DELETE FROM languages;\n", + " DELETE FROM availability;\n", + " DELETE FROM caregiver_services;\n", + " DELETE FROM caregivers;\n", + "\n", + " -- Seed caregivers (20 examples, all USA)\n", + "\n", + " INSERT INTO caregivers\n", + " (id, name, gender, years_experience, live_in, hourly_rate, currency, city, state_province, country, postal_code, lat, lon)\n", + " VALUES\n", + " (1, 'Grace Williams', 'female', 6, 0, 28, 'USD', 'New York', 'NY', 'USA', '10001', 40.7128, -74.0060),\n", + " (2, 'Miguel Alvarez', 'male', 9, 1, 30, 'USD', 'Los Angeles', 'CA', 'USA', '90012', 34.0522, -118.2437),\n", + " (3, 'Ava Johnson', 'female', 4, 0, 24, 'USD', 'Chicago', 'IL', 'USA', '60601', 41.8781, -87.6298),\n", + " (4, 'Noah Robinson', 'male', 12, 0, 27, 'USD', 'Houston', 'TX', 'USA', '77002', 29.7604, -95.3698),\n", + " (5, 'Sophia Martinez', 'female', 8, 0, 29, 'USD', 'Phoenix', 'AZ', 'USA', '85004', 33.4484, -112.0740),\n", + " (6, 'Daniel Carter', 'male', 10, 1, 31, 'USD', 'Philadelphia', 'PA', 'USA', '19103', 39.9526, -75.1652),\n", + " (7, 'Emily Nguyen', 'female', 7, 0, 26, 'USD', 'San Antonio', 'TX', 'USA', '78205', 29.4241, -98.4936),\n", + " (8, 'Olivia Kim', 'female', 5, 0, 27, 'USD', 'San Diego', 'CA', 'USA', '92101', 32.7157, -117.1611),\n", + " (9, 'James Thompson', 'male', 15, 1, 34, 'USD', 'Dallas', 'TX', 'USA', '75201', 32.7767, -96.7970),\n", + " (10, 'Isabella Garcia', 'female', 3, 0, 22, 'USD', 'San Jose', 'CA', 'USA', '95113', 37.3382, -121.8863),\n", + " (11, 'Ethan Patel', 'male', 11, 1, 33, 'USD', 'Austin', 'TX', 'USA', '78701', 30.2672, -97.7431),\n", + " (12, 'Harper Brooks', 'female', 2, 0, 20, 'USD', 'Jacksonville', 'FL', 'USA', '32202', 30.3322, -81.6557),\n", + " (13, 'Logan White', 'male', 6, 0, 25, 'USD', 'Fort Worth', 'TX', 'USA', '76102', 32.7555, -97.3308),\n", + " (14, 'Amelia Davis', 'female', 9, 0, 28, 'USD', 'Columbus', 'OH', 'USA', '43215', 39.9612, -82.9988),\n", + " (15, 'Charlotte Reed', 'female', 14, 1, 32, 'USD', 'Charlotte', 'NC', 'USA', '28202', 35.2271, -80.8431),\n", + " (16, 'Jackson Lee', 'male', 5, 0, 26, 'USD', 'San Francisco', 'CA', 'USA', '94102', 37.7749, -122.4194),\n", + " (17, 'Avery Chen', 'female', 7, 0, 27, 'USD', 'Seattle', 'WA', 'USA', '98101', 47.6062, -122.3321),\n", + " (18, 'William Turner', 'male', 13, 1, 35, 'USD', 'Denver', 'CO', 'USA', '80202', 39.7392, -104.9903),\n", + " (19, 'Natalie O''Brien', 'female', 16, 0, 36, 'USD', 'Boston', 'MA', 'USA', '02108', 42.3601, -71.0589),\n", + " (20, 'Maya Robinson', 'female', 3, 0, 23, 'USD', 'Atlanta', 'GA', 'USA', '30303', 33.7488, -84.3880);\n", + "\n", + " -- Seed caregiver services\n", + "\n", + " INSERT INTO caregiver_services (caregiver_id, care_type) VALUES\n", + " (1, 'elder care'), (1, 'companionship'),\n", + " (2, 'post-op support'), (2, 'elder care'),\n", + " (3, 'child care'), (3, 'special needs'),\n", + " (4, 'respite care'), (4, 'elder care'),\n", + " (5, 'dementia care'), (5, 'companionship'),\n", + " (6, 'elder care'), (6, 'hospice support'),\n", + " (7, 'child care'), (7, 'respite care'),\n", + " (8, 'post-op support'), (8, 'companionship'),\n", + " (9, 'special needs'), (9, 'elder care'),\n", + " (10, 'child care'), (10, 
'companionship'),\n", + " (11, 'dementia care'), (11, 'post-op support'),\n", + " (12, 'child care'), (12, 'special needs'),\n", + " (13, 'respite care'), (13, 'companionship'),\n", + " (14, 'elder care'), (14, 'post-op support'),\n", + " (15, 'hospice support'), (15, 'dementia care'),\n", + " (16, 'elder care'), (16, 'respite care'),\n", + " (17, 'special needs'), (17, 'companionship'),\n", + " (18, 'post-op support'), (18, 'elder care'),\n", + " (19, 'dementia care'), (19, 'hospice support'),\n", + " (20, 'child care'), (20, 'companionship');\n", + "\n", + " -- Seed availability (Mon-Sun samples)\n", + "\n", + " INSERT INTO availability (caregiver_id, day, time_start, time_end) VALUES\n", + " -- 1 Grace (NY): evenings + Sun\n", + " (1, 'Mon', '17:30', '22:00'),\n", + " (1, 'Thu', '17:30', '22:00'),\n", + " (1, 'Sun', '10:00', '16:00'),\n", + " -- 2 Miguel (LA): live-in, long blocks\n", + " (2, 'Tue', '08:00', '20:00'),\n", + " (2, 'Thu', '08:00', '20:00'),\n", + " (2, 'Sat', '09:00', '18:00'),\n", + " -- 3 Ava (CHI): weekdays 09-17\n", + " (3, 'Mon', '09:00', '17:00'),\n", + " (3, 'Wed', '09:00', '17:00'),\n", + " (3, 'Fri', '09:00', '17:00'),\n", + " -- 4 Noah (HOU): Tue-Fri 08-16\n", + " (4, 'Tue', '08:00', '16:00'),\n", + " (4, 'Wed', '08:00', '16:00'),\n", + " (4, 'Thu', '08:00', '16:00'),\n", + " -- 5 Sophia (PHX): Thu-Sun 10-18\n", + " (5, 'Thu', '10:00', '18:00'),\n", + " (5, 'Fri', '10:00', '18:00'),\n", + " (5, 'Sat', '10:00', '18:00'),\n", + " -- 6 Daniel (PHL): Mon-Thu 07-15\n", + " (6, 'Mon', '07:00', '15:00'),\n", + " (6, 'Tue', '07:00', '15:00'),\n", + " (6, 'Thu', '07:00', '15:00'),\n", + " -- 7 Emily (SAT): weekends\n", + " (7, 'Sat', '08:00', '17:00'),\n", + " (7, 'Sun', '09:00', '17:00'),\n", + " (7, 'Fri', '17:00', '21:00'),\n", + " -- 8 Olivia (SD): Mon, Wed evenings\n", + " (8, 'Mon', '16:00', '21:00'),\n", + " (8, 'Wed', '16:00', '21:00'),\n", + " (8, 'Sat', '10:00', '14:00'),\n", + " -- 9 James (DAL): live-in wide\n", + " (9, 'Mon', '07:00', '19:00'),\n", + " (9, 'Wed', '07:00', '19:00'),\n", + " (9, 'Sun', '09:00', '17:00'),\n", + " -- 10 Isabella (SJ): Tue-Thu 12-20\n", + " (10, 'Tue', '12:00', '20:00'),\n", + " (10, 'Wed', '12:00', '20:00'),\n", + " (10, 'Thu', '12:00', '20:00'),\n", + " -- 11 Ethan (ATX): nights\n", + " (11, 'Mon', '18:00', '23:00'),\n", + " (11, 'Tue', '18:00', '23:00'),\n", + " (11, 'Fri', '18:00', '23:00'),\n", + " -- 12 Harper (JAX): school hours\n", + " (12, 'Mon', '09:00', '14:00'),\n", + " (12, 'Wed', '09:00', '14:00'),\n", + " (12, 'Fri', '09:00', '14:00'),\n", + " -- 13 Logan (FTW): Thu-Sat\n", + " (13, 'Thu', '10:00', '18:00'),\n", + " (13, 'Fri', '10:00', '18:00'),\n", + " (13, 'Sat', '10:00', '18:00'),\n", + " -- 14 Amelia (CMH): Mon-Fri 08-16\n", + " (14, 'Mon', '08:00', '16:00'),\n", + " (14, 'Tue', '08:00', '16:00'),\n", + " (14, 'Thu', '08:00', '16:00'),\n", + " -- 15 Charlotte (CLT): live-in style\n", + " (15, 'Tue', '07:00', '19:00'),\n", + " (15, 'Thu', '07:00', '19:00'),\n", + " (15, 'Sat', '08:00', '16:00'),\n", + " -- 16 Jackson (SF): split shifts\n", + " (16, 'Mon', '07:00', '11:00'),\n", + " (16, 'Mon', '17:00', '21:00'),\n", + " (16, 'Sat', '12:00', '18:00'),\n", + " -- 17 Avery (SEA): Tue/Thu + Sun\n", + " (17, 'Tue', '10:00', '18:00'),\n", + " (17, 'Thu', '10:00', '18:00'),\n", + " (17, 'Sun', '11:00', '17:00'),\n", + " -- 18 William (DEN): Mon-Wed 06-14\n", + " (18, 'Mon', '06:00', '14:00'),\n", + " (18, 'Tue', '06:00', '14:00'),\n", + " (18, 'Wed', '06:00', '14:00'),\n", + " -- 19 Natalie (BOS): Tue-Fri 09-17\n", + 
" (19, 'Tue', '09:00', '17:00'),\n", + " (19, 'Wed', '09:00', '17:00'),\n", + " (19, 'Fri', '09:00', '17:00'),\n", + " -- 20 Maya (ATL): after-school + Sat\n", + " (20, 'Mon', '15:00', '20:00'),\n", + " (20, 'Wed', '15:00', '20:00'),\n", + " (20, 'Sat', '09:00', '15:00');\n", + "\n", + " -- Seed languages\n", + "\n", + " INSERT INTO languages (caregiver_id, language) VALUES\n", + " (1, 'English'), (1, 'Spanish'),\n", + " (2, 'English'), (2, 'Spanish'),\n", + " (3, 'English'),\n", + " (4, 'English'),\n", + " (5, 'English'), (5, 'Spanish'),\n", + " (6, 'English'),\n", + " (7, 'English'), (7, 'Vietnamese'),\n", + " (8, 'English'), (8, 'Korean'),\n", + " (9, 'English'),\n", + " (10,'English'), (10,'Spanish'),\n", + " (11,'English'), (11,'Hindi'),\n", + " (12,'English'),\n", + " (13,'English'),\n", + " (14,'English'), (14,'French'),\n", + " (15,'English'),\n", + " (16,'English'), (16,'Tagalog'),\n", + " (17,'English'), (17,'Mandarin'),\n", + " (18,'English'),\n", + " (19,'English'), (19,'Portuguese'),\n", + " (20,'English'), (20,'ASL');\n", + "\n", + " -- Seed certifications\n", + "\n", + " INSERT INTO certifications (caregiver_id, cert) VALUES\n", + " (1, 'CPR'), (1, 'First Aid'),\n", + " (2, 'CPR'), (2, 'BLS'),\n", + " (3, 'CPR'),\n", + " (4, 'First Aid'), (4, 'CNA'),\n", + " (5, 'CPR'), (5, 'Dementia Care'),\n", + " (6, 'HHA'), (6, 'CPR'),\n", + " (7, 'First Aid'),\n", + " (8, 'CPR'), (8, 'AED'),\n", + " (9, 'CNA'), (9, 'BLS'),\n", + " (10,'First Aid'),\n", + " (11,'CPR'), (11,'Medication Technician'),\n", + " (12,'CPR'),\n", + " (13,'First Aid'),\n", + " (14,'CPR'), (14,'CNA'),\n", + " (15,'Hospice Training'), (15,'CPR'),\n", + " (16,'First Aid'),\n", + " (17,'CPR'), (17,'Special Needs Training'),\n", + " (18,'BLS'), (18,'CPR'),\n", + " (19,'Dementia Care'), (19,'First Aid'),\n", + " (20,'CPR'), (20,'Childcare Safety');\n", + "\n", + " -- Seed traits\n", + "\n", + " INSERT INTO traits (caregiver_id, trait) VALUES\n", + " (1, 'empathetic'), (1, 'detail-oriented'),\n", + " (2, 'patient'), (2, 'communicative'),\n", + " (3, 'cheerful'), (3, 'reliable'),\n", + " (4, 'organized'), (4, 'professional'),\n", + " (5, 'compassionate'), (5, 'trustworthy'),\n", + " (6, 'calm under pressure'), (6, 'punctual'),\n", + " (7, 'adaptable'), (7, 'energetic'),\n", + " (8, 'friendly'), (8, 'respectful'),\n", + " (9, 'thorough'), (9, 'dependable'),\n", + " (10,'gentle'), (10,'attentive'),\n", + " (11,'proactive'), (11,'communicative'),\n", + " (12,'patient'), (12,'kind'),\n", + " (13,'flexible'), (13,'tidy'),\n", + " (14,'reliable'), (14,'punctual'),\n", + " (15,'compassionate'), (15,'detail-oriented'),\n", + " (16,'discreet'), (16,'organized'),\n", + " (17,'empathetic'), (17,'calm under pressure'),\n", + " (18,'professional'), (18,'thorough'),\n", + " (19,'trustworthy'), (19,'proactive'),\n", + " (20,'cheerful'), (20,'attentive');\n", + "\n", + "'''\n", + "\n", + "# Insert the data into the database\n", + "\n", + "sql = dedent(SQL)\n", + "con = sqlite3.connect(DB_PATH)\n", + "con.executescript(sql)\n", + "con.commit()\n", + "con.close()\n", + "print(\"Seeded:\", DB_PATH)\n" + ] + }, + { + "cell_type": "markdown", + "id": "3c0baa64", + "metadata": {}, + "source": [ + "## Step 4: Teaching the AI to Search\n", + "\n", + "Instead of the AI just talking, we teach it to actually *search* the database.\n", + "\n", + "When someone says *\"I need elder care in Boston for Mondays\"*, the AI translates that into:\n", + "```python\n", + "search_caregivers(city=\"Boston\", care_type=\"elder care\", day=\"Mon\")\n", + 
"```\n", + "\n", + "This schema defines all the filters the AI can use: location, services, budget, language, availability, and more.\n", + "\n", + "**This was the breakthrough.** No more writing custom queries—the AI figures it out.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "f2af7c67", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'type': 'function',\n", + " 'function': {'name': 'search_caregivers',\n", + " 'description': 'Flexible multi-filter caregiver search. Any filter can be omitted. Supports location, service type, experience, pricing, live-in, language, certifications, day/time availability, and pagination.',\n", + " 'parameters': {'type': 'object',\n", + " 'properties': {'city': {'type': 'string',\n", + " 'description': 'City name to filter by (optional).'},\n", + " 'state_province': {'type': 'string',\n", + " 'description': 'State or province to filter by (optional).'},\n", + " 'country': {'type': 'string',\n", + " 'description': 'Country to filter by (optional).'},\n", + " 'care_type': {'type': 'string',\n", + " 'description': \"Service category, e.g., 'elder_care', 'child_care', 'pet_care', 'housekeeping' (optional).\"},\n", + " 'min_experience': {'type': 'integer',\n", + " 'minimum': 0,\n", + " 'description': 'Minimum years of experience (optional).'},\n", + " 'max_hourly_rate': {'type': 'number',\n", + " 'minimum': 0,\n", + " 'description': 'Maximum hourly rate in local currency (optional).'},\n", + " 'live_in': {'type': 'boolean',\n", + " 'description': 'Require live-in caregivers (optional).'},\n", + " 'language': {'type': 'string',\n", + " 'description': \"Required spoken language, e.g., 'English', 'Spanish' (optional).\"},\n", + " 'certification': {'type': 'string',\n", + " 'description': \"Required certification, e.g., 'CPR', 'CNA' (optional).\"},\n", + " 'day': {'type': 'string',\n", + " 'description': \"Day of week to match availability (optional), e.g., 'Monday', 'Tuesday', ... 'Sunday'.\"},\n", + " 'time_between': {'type': 'array',\n", + " 'description': \"Required availability window as ['HH:MM','HH:MM'] in 24h time. Matches caregivers whose availability window fully covers this range.\",\n", + " 'items': {'type': 'string',\n", + " 'pattern': '^\\\\d{2}:\\\\d{2}$',\n", + " 'description': \"Time in 'HH:MM' 24-hour format.\"},\n", + " 'minItems': 2,\n", + " 'maxItems': 2},\n", + " 'limit': {'type': 'integer',\n", + " 'minimum': 1,\n", + " 'maximum': 1000,\n", + " 'default': 50,\n", + " 'description': 'Max number of results to return (default 50).'},\n", + " 'offset': {'type': 'integer',\n", + " 'minimum': 0,\n", + " 'default': 0,\n", + " 'description': 'Number of results to skip for pagination (default 0).'}},\n", + " 'required': [],\n", + " 'additionalProperties': False}}}]" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Tool definition schema\n", + "\n", + "tools = [{\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"search_caregivers\",\n", + " \"description\": (\n", + " \"Flexible multi-filter caregiver search. Any filter can be omitted. 
\"\n", + " \"Supports location, service type, experience, pricing, live-in, language, \"\n", + " \"certifications, day/time availability, and pagination.\"\n", + " ),\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"City name to filter by (optional).\"\n", + " },\n", + " \"state_province\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"State or province to filter by (optional).\"\n", + " },\n", + " \"country\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Country to filter by (optional).\"\n", + " },\n", + " \"care_type\": {\n", + " \"type\": \"string\",\n", + " \"description\": (\n", + " \"Service category, e.g., 'elder_care', 'child_care', \"\n", + " \"'pet_care', 'housekeeping' (optional).\"\n", + " )\n", + " },\n", + " \"min_experience\": {\n", + " \"type\": \"integer\",\n", + " \"minimum\": 0,\n", + " \"description\": \"Minimum years of experience (optional).\"\n", + " },\n", + " \"max_hourly_rate\": {\n", + " \"type\": \"number\",\n", + " \"minimum\": 0,\n", + " \"description\": \"Maximum hourly rate in local currency (optional).\"\n", + " },\n", + " \"live_in\": {\n", + " \"type\": \"boolean\",\n", + " \"description\": \"Require live-in caregivers (optional).\"\n", + " },\n", + " \"language\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Required spoken language, e.g., 'English', 'Spanish' (optional).\"\n", + " },\n", + " \"certification\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Required certification, e.g., 'CPR', 'CNA' (optional).\"\n", + " },\n", + " \"day\": {\n", + " \"type\": \"string\",\n", + " \"description\": (\n", + " \"Day of week to match availability (optional), e.g., \"\n", + " \"'Monday', 'Tuesday', ... 'Sunday'.\"\n", + " )\n", + " },\n", + " \"time_between\": {\n", + " \"type\": \"array\",\n", + " \"description\": (\n", + " \"Required availability window as ['HH:MM','HH:MM'] in 24h time. \"\n", + " \"Matches caregivers whose availability window fully covers this range.\"\n", + " ),\n", + " \"items\": {\n", + " \"type\": \"string\",\n", + " \"pattern\": \"^\\\\d{2}:\\\\d{2}$\",\n", + " \"description\": \"Time in 'HH:MM' 24-hour format.\"\n", + " },\n", + " \"minItems\": 2,\n", + " \"maxItems\": 2\n", + " },\n", + " \"limit\": {\n", + " \"type\": \"integer\",\n", + " \"minimum\": 1,\n", + " \"maximum\": 1000,\n", + " \"default\": 50,\n", + " \"description\": \"Max number of results to return (default 50).\"\n", + " },\n", + " \"offset\": {\n", + " \"type\": \"integer\",\n", + " \"minimum\": 0,\n", + " \"default\": 0,\n", + " \"description\": \"Number of results to skip for pagination (default 0).\"\n", + " }\n", + " },\n", + " \"required\": [],\n", + " \"additionalProperties\": False\n", + " }\n", + " }\n", + "}]\n", + "\n", + "tools" + ] + }, + { + "cell_type": "markdown", + "id": "76416da2", + "metadata": {}, + "source": [ + "## Step 5: Helper Functions\n", + "\n", + "**Voice:** The AI can speak its responses using OpenAI's text-to-speech.\n", + "\n", + "**Database functions:** All the queries we need—search, get profiles, check availability, etc. 
These are what the AI calls behind the scenes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "2f50cc15", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Convert text to speech using OpenAI's TTS API\n", + "def announcements(message):\n", + " response = openai.audio.speech.create(\n", + " model=\"gpt-4o-mini-tts\",\n", + " voice=\"coral\", # Also, try replacing onyx with alloy or coral\n", + " input=message\n", + " )\n", + " return response.content\n", + "\n", + "# Context manager for database connection\n", + "@contextmanager\n", + "def _conn(dict_rows: bool = True):\n", + " conn = sqlite3.connect(DB_PATH)\n", + " if dict_rows:\n", + " conn.row_factory = _dict_factory\n", + " try:\n", + " yield conn\n", + " conn.commit()\n", + " finally:\n", + " conn.close()\n", + "\n", + "####################\n", + "# Helper functions #\n", + "####################\n", + "\n", + "# Converts SQLite query results from tuples into dictionaries\n", + "def _dict_factory(cursor, row):\n", + " return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}\n", + "# A debug/logging function that prints database tool activity\n", + "def _print(msg: str):\n", + " print(f\"DATABASE TOOL CALLED: {msg}\", flush=True)\n", + "\n", + "################################\n", + "# Caregiver database functions #\n", + "################################\n", + "\n", + "# Counts the number of caregivers in the database\n", + "def get_caregiver_count() -> int:\n", + " _print(\"Counting caregivers\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT COUNT(*) AS n FROM caregivers\")\n", + " return cur.fetchone()[\"n\"]\n", + "\n", + "# Fetches a caregiver's profile by their ID\n", + "def get_caregiver(caregiver_id: int) -> Optional[Dict[str, Any]]:\n", + " _print(f\"Fetching caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"SELECT * FROM caregivers WHERE id = ?\", (caregiver_id,))\n", + " return cur.fetchone()\n", + "\n", + "# Lists caregivers with pagination\n", + "def list_caregivers(limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:\n", + " _print(f\"Listing caregivers (limit={limit}, offset={offset})\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT * FROM caregivers\n", + " ORDER BY id\n", + " LIMIT ? 
OFFSET ?\n", + " \"\"\", (limit, offset))\n", + " return cur.fetchall()\n", + "\n", + "# Fetches the services a caregiver offers\n", + "def get_services(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching services for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT care_type FROM caregiver_services WHERE caregiver_id = ?\n", + " ORDER BY care_type\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"care_type\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the languages a caregiver speaks\n", + "def get_languages(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching languages for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT language FROM languages WHERE caregiver_id = ?\n", + " ORDER BY language\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"language\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the certifications a caregiver has\n", + "def get_certifications(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching certifications for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT cert FROM certifications WHERE caregiver_id = ?\n", + " ORDER BY cert\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"cert\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the traits a caregiver has\n", + "def get_traits(caregiver_id: int) -> List[str]:\n", + " _print(f\"Fetching traits for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT trait FROM traits WHERE caregiver_id = ?\n", + " ORDER BY trait\n", + " \"\"\", (caregiver_id,))\n", + " return [r[\"trait\"] for r in cur.fetchall()]\n", + "\n", + "# Fetches the availability of a caregiver\n", + "def get_availability(caregiver_id: int) -> List[Dict[str, str]]:\n", + " _print(f\"Fetching availability for caregiver #{caregiver_id}\")\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(\"\"\"\n", + " SELECT day, time_start, time_end\n", + " FROM availability\n", + " WHERE caregiver_id = ?\n", + " ORDER BY\n", + " CASE day\n", + " WHEN 'Mon' THEN 1 WHEN 'Tue' THEN 2 WHEN 'Wed' THEN 3\n", + " WHEN 'Thu' THEN 4 WHEN 'Fri' THEN 5 WHEN 'Sat' THEN 6\n", + " WHEN 'Sun' THEN 7 ELSE 8\n", + " END, time_start\n", + " \"\"\", (caregiver_id,))\n", + " return cur.fetchall()\n", + "\n", + "# Fetches a caregiver's full profile\n", + "def get_caregiver_profile(caregiver_id: int) -> Optional[Dict[str, Any]]:\n", + " _print(f\"Fetching full profile for caregiver #{caregiver_id}\")\n", + " base = get_caregiver(caregiver_id)\n", + " if not base:\n", + " return None\n", + " base[\"services\"] = get_services(caregiver_id)\n", + " base[\"languages\"] = get_languages(caregiver_id)\n", + " base[\"certifications\"] = get_certifications(caregiver_id)\n", + " base[\"traits\"] = get_traits(caregiver_id)\n", + " base[\"availability\"] = get_availability(caregiver_id)\n", + " return base\n", + "\n", + "###########################################\n", + "# Search caregivers with multiple filters #\n", + "###########################################\n", + "\n", + "def search_caregivers(\n", + " city: Optional[str] = None,\n", + " state_province: Optional[str] = None,\n", + " country: Optional[str] = None,\n", + " care_type: Optional[str] = None,\n", + " min_experience: Optional[int] = None,\n", + " 
max_hourly_rate: Optional[float] = None,\n", + " live_in: Optional[bool] = None,\n", + " language: Optional[str] = None,\n", + " certification: Optional[str] = None,\n", + " day: Optional[str] = None,\n", + " time_between: Optional[Tuple[str, str]] = None, # ('HH:MM', 'HH:MM')\n", + " limit: int = 50,\n", + " offset: int = 0\n", + ") -> List[Dict[str, Any]]:\n", + " \"\"\"\n", + " Flexible multi-filter search. Any filter can be omitted.\n", + " \"\"\"\n", + " _print(\"Searching caregivers with multiple filters\")\n", + "\n", + " # base + optional joins\n", + " join_clauses = []\n", + " where = [\"1=1\"]\n", + " params: List[Any] = []\n", + "\n", + " if care_type:\n", + " join_clauses.append(\"JOIN caregiver_services s ON s.caregiver_id = c.id\")\n", + " where.append(\"LOWER(s.care_type) = LOWER(?)\")\n", + " params.append(care_type)\n", + "\n", + " if language:\n", + " join_clauses.append(\"JOIN languages l ON l.caregiver_id = c.id\")\n", + " where.append(\"LOWER(l.language) = LOWER(?)\")\n", + " params.append(language)\n", + "\n", + " if certification:\n", + " join_clauses.append(\"JOIN certifications cert ON cert.caregiver_id = c.id\")\n", + " where.append(\"LOWER(cert.cert) = LOWER(?)\")\n", + " params.append(certification)\n", + "\n", + " if day or time_between:\n", + " join_clauses.append(\"JOIN availability a ON a.caregiver_id = c.id\")\n", + " if day:\n", + " where.append(\"a.day = ?\")\n", + " params.append(day)\n", + " if time_between:\n", + " t0, t1 = time_between\n", + " # overlap check: caregiver window [start,end] must include [t0,t1]\n", + " where.append(\"a.time_start <= ? AND a.time_end >= ?\")\n", + " params.extend([t0, t1])\n", + "\n", + " if city:\n", + " where.append(\"LOWER(c.city) = LOWER(?)\")\n", + " params.append(city)\n", + " if state_province:\n", + " where.append(\"LOWER(c.state_province) = LOWER(?)\")\n", + " params.append(state_province)\n", + " if country:\n", + " where.append(\"LOWER(c.country) = LOWER(?)\")\n", + " params.append(country)\n", + " if min_experience is not None:\n", + " where.append(\"c.years_experience >= ?\")\n", + " params.append(min_experience)\n", + " if max_hourly_rate is not None:\n", + " where.append(\"c.hourly_rate <= ?\")\n", + " params.append(max_hourly_rate)\n", + " if live_in is not None:\n", + " where.append(\"c.live_in = ?\")\n", + " params.append(1 if live_in else 0)\n", + "\n", + " sql = f\"\"\"\n", + " SELECT DISTINCT c.*\n", + " FROM caregivers c\n", + " {' '.join(join_clauses)}\n", + " WHERE {' AND '.join(where)}\n", + " ORDER BY c.hourly_rate ASC, c.years_experience DESC, c.id\n", + " LIMIT ? OFFSET ?\n", + " \"\"\"\n", + " params.extend([limit, offset])\n", + "\n", + " with _conn() as conn:\n", + " cur = conn.cursor()\n", + " cur.execute(sql, tuple(params))\n", + " return cur.fetchall()" + ] + }, + { + "cell_type": "markdown", + "id": "6c526d05", + "metadata": {}, + "source": [ + "## Step 6: Quick Test\n", + "\n", + "Before connecting everything to the AI, let's make sure the database works. 
Run these examples to see sample caregivers and their profiles.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "98165a21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DATABASE TOOL CALLED: Searching caregivers with multiple filters\n", + "Found 1 elder care providers in New York:\n", + "- Grace Williams: $28.0/hr, 6 years experience\n", + "\n", + "============================================================\n", + "\n", + "DATABASE TOOL CALLED: Searching caregivers with multiple filters\n", + "Found 1 Spanish-speaking child care providers:\n", + "- Isabella Garcia in San Jose, CA\n", + "\n", + "============================================================\n", + "\n", + "DATABASE TOOL CALLED: Fetching full profile for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching caregiver #1\n", + "DATABASE TOOL CALLED: Fetching services for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching languages for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching certifications for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching traits for caregiver #1\n", + "DATABASE TOOL CALLED: Fetching availability for caregiver #1\n", + "Detailed profile for Grace Williams:\n", + " Services: companionship, elder care\n", + " Languages: English, Spanish\n", + " Certifications: CPR, First Aid\n", + " Traits: detail-oriented, empathetic\n", + " Availability: 3 time slots\n" + ] + } + ], + "source": [ + "# Example 1: Search for elder care providers in New York\n", + "results = search_caregivers(\n", + " city=\"New York\",\n", + " care_type=\"elder care\",\n", + " max_hourly_rate=30.0,\n", + " limit=5\n", + ")\n", + "\n", + "print(f\"Found {len(results)} elder care providers in New York:\")\n", + "for caregiver in results:\n", + " print(f\"- {caregiver['name']}: ${caregiver['hourly_rate']}/hr, {caregiver['years_experience']} years experience\")\n", + "\n", + "print(\"\\n\" + \"=\"*60 + \"\\n\")\n", + "\n", + "# Example 2: Search for Spanish-speaking child care providers\n", + "results2 = search_caregivers(\n", + " care_type=\"child care\",\n", + " language=\"Spanish\",\n", + " limit=3\n", + ")\n", + "\n", + "print(f\"Found {len(results2)} Spanish-speaking child care providers:\")\n", + "for caregiver in results2:\n", + " print(f\"- {caregiver['name']} in {caregiver['city']}, {caregiver['state_province']}\")\n", + "\n", + "print(\"\\n\" + \"=\"*60 + \"\\n\")\n", + "\n", + "# Example 3: Get detailed profile of a specific caregiver\n", + "if results:\n", + " caregiver_id = results[0]['id']\n", + " profile = get_caregiver_profile(caregiver_id)\n", + " print(f\"Detailed profile for {profile['name']}:\")\n", + " print(f\" Services: {', '.join(profile['services'])}\")\n", + " print(f\" Languages: {', '.join(profile['languages'])}\")\n", + " print(f\" Certifications: {', '.join(profile['certifications'])}\")\n", + " print(f\" Traits: {', '.join(profile['traits'])}\")\n", + " print(f\" Availability: {len(profile['availability'])} time slots\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "abfa81e6", + "metadata": {}, + "source": [ + "## Step 7: The AI's Instructions\n", + "\n", + "Here's where I learned prompt engineering matters *a lot*.\n", + "\n", + "The AI needs to know:\n", + "- What exact keywords to use (\"elder care\" not \"elderly care\", \"Mon\" not \"Monday\")\n", + "- How to map natural language to database values\n", + "- That it should give 2-3 recommendations with pros/cons\n", + "- To remind families to verify credentials 
independently\n", + "\n", + "**The lesson from MyWoosah:** Small keyword mismatches = zero results. This prompt prevents that.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "7bbe36e3", + "metadata": {}, + "outputs": [], + "source": [ + "# System prompt\n", + "\n", + "system_prompt = '''\n", + " You are a compassionate Caregiver Assistant that helps families quickly identify the most\n", + " suitable care provider by gathering requirements (care needs, schedule, budget, location,\n", + " language/cultural fit) and matching them to available profiles. Provide 2-3 best-fit options\n", + " with pros/cons, estimated costs, and next steps, and clearly state that credentials/background\n", + " checks are not verified by this sample app and should be confirmed by the family.\n", + "\n", + " CRITICAL: When searching the database, you MUST use these EXACT terms:\n", + "\n", + " CARE TYPES (use exactly as shown):\n", + " - \"elder care\" (for elderly, senior, old age, geriatric care)\n", + " - \"companionship\" (for companion, friendship, social support)\n", + " - \"post-op support\" (for post-surgery, post-operative, recovery care)\n", + " - \"child care\" (for children, kids, babysitting, nanny)\n", + " - \"special needs\" (for disabilities, autism, developmental needs)\n", + " - \"respite care\" (for temporary relief, break for family caregivers)\n", + " - \"dementia care\" (for Alzheimer's, memory care, cognitive decline)\n", + " - \"hospice support\" (for end-of-life, palliative, terminal care)\n", + "\n", + " If a user mentions any variation, map it to the closest match above. If unclear, ask clarifying questions.\n", + "\n", + " DAYS OF WEEK (use exactly as shown):\n", + " - \"Mon\" (for Monday)\n", + " - \"Tue\" (for Tuesday)\n", + " - \"Wed\" (for Wednesday)\n", + " - \"Thu\" (for Thursday)\n", + " - \"Fri\" (for Friday)\n", + " - \"Sat\" (for Saturday)\n", + " - \"Sun\" (for Sunday)\n", + "\n", + " STATES/PROVINCES (use 2-letter codes):\n", + " - Use standard US state abbreviations: \"NY\", \"CA\", \"TX\", \"FL\", \"MA\", etc.\n", + " - Convert full state names to abbreviations before searching\n", + "\n", + " COMMON LANGUAGES:\n", + " - \"English\", \"Spanish\", \"French\", \"Vietnamese\", \"Korean\", \"Hindi\", \"Mandarin\", \"Portuguese\", \"Tagalog\", \"ASL\"\n", + " - Capitalize properly (e.g., user says \"spanish\" → use \"Spanish\")\n", + "\n", + " CERTIFICATIONS:\n", + " - \"CPR\", \"First Aid\", \"CNA\", \"BLS\", \"HHA\", \"AED\", \"Medication Technician\", \"Hospice Training\", \n", + " \"Dementia Care\", \"Special Needs Training\", \"Childcare Safety\"\n", + " - Use exact capitalization and full names\n", + "\n", + " TRAITS:\n", + " - \"empathetic\", \"patient\", \"cheerful\", \"organized\", \"compassionate\", \"calm under pressure\", \n", + " \"adaptable\", \"friendly\", \"thorough\", \"gentle\", \"proactive\", \"flexible\", \"reliable\", \n", + " \"detail-oriented\", \"communicative\", \"energetic\", \"respectful\", \"dependable\", \"attentive\", \n", + " \"kind\", \"tidy\", \"punctual\", \"discreet\", \"professional\", \"trustworthy\"\n", + " - Use lowercase for all traits\n", + "\n", + " SEARCH STRATEGY:\n", + " 1. Listen carefully to user requirements\n", + " 2. Map their natural language to database terms above\n", + " 3. Use search_caregivers() with exact keyword matches\n", + " 4. If no results, suggest alternatives or broader searches\n", + " 5. 
After getting results, use get_caregiver_profile() for detailed information on top matches\n", + "\n", + " Always confirm your understanding by restating requirements using the exact database terms before searching.\n", + "'''" + ] + }, + { + "cell_type": "markdown", + "id": "0b8ae902", + "metadata": {}, + "source": [ + "## Step 8: Making it Work (and Not Crash)\n", + "\n", + "This is the engine room. When the AI wants to search, this code:\n", + "1. Validates the request\n", + "2. Calls the right database function\n", + "3. Handles errors gracefully (no crashes!)\n", + "4. Limits results to prevent overwhelming the AI\n", + "5. Generates the voice response\n", + "\n", + "**Defensive programming:** I learned the hard way that things break. This code expects problems and handles them elegantly.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "0d8accbc", + "metadata": {}, + "outputs": [], + "source": [ + "# Function registry: Maps tool names to actual Python functions\n", + "TOOL_REGISTRY = {\n", + " \"search_caregivers\": search_caregivers,\n", + " \"get_caregiver_count\": get_caregiver_count,\n", + " \"get_caregiver\": get_caregiver,\n", + " \"list_caregivers\": list_caregivers,\n", + " \"get_services\": get_services,\n", + " \"get_languages\": get_languages,\n", + " \"get_certifications\": get_certifications,\n", + " \"get_traits\": get_traits,\n", + " \"get_availability\": get_availability,\n", + " \"get_caregiver_profile\": get_caregiver_profile,\n", + "}\n", + "\n", + "def execute_tool_call(tool_call):\n", + " \"\"\"\n", + " Safely execute a single tool call with error handling.\n", + " Returns a properly formatted tool response.\n", + " \"\"\"\n", + " import json\n", + " \n", + " function_name = tool_call.function.name\n", + " \n", + " # Defensive check: Ensure function exists in registry\n", + " if function_name not in TOOL_REGISTRY:\n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"content\": json.dumps({\n", + " \"error\": f\"Unknown function: {function_name}\",\n", + " \"available_functions\": list(TOOL_REGISTRY.keys())\n", + " })\n", + " }\n", + " \n", + " try:\n", + " # Parse arguments\n", + " args = json.loads(tool_call.function.arguments)\n", + " \n", + " # Execute the function\n", + " func = TOOL_REGISTRY[function_name]\n", + " result = func(**args)\n", + " \n", + " # Format response based on result type with limit to prevent token overflow\n", + " if isinstance(result, list):\n", + " content = json.dumps({\n", + " \"count\": len(result),\n", + " \"results\": result[:10] if len(result) > 10 else result,\n", + " \"truncated\": len(result) > 10\n", + " })\n", + " elif isinstance(result, dict):\n", + " content = json.dumps(result)\n", + " elif isinstance(result, (int, float, str)):\n", + " content = json.dumps({\"result\": result})\n", + " else:\n", + " content = str(result)\n", + " \n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"content\": content\n", + " }\n", + " \n", + " except Exception as e:\n", + " # Defensive error handling\n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"content\": json.dumps({\n", + " \"error\": str(e),\n", + " \"function\": function_name,\n", + " \"args\": tool_call.function.arguments\n", + " })\n", + " }\n", + "\n", + "def process_tool_calls(message):\n", + " \"\"\"\n", + " Process all tool calls from the AI response.\n", + " Returns tool responses and extracted metadata.\n", + " \"\"\"\n", 
+ " responses = []\n", + " metadata = {\n", + " \"cities\": set(),\n", + " \"caregiver_ids\": set(),\n", + " \"total_results\": 0\n", + " }\n", + " \n", + " if not message.tool_calls:\n", + " return responses, metadata\n", + " \n", + " for tool_call in message.tool_calls:\n", + " # Execute the tool call\n", + " response = execute_tool_call(tool_call)\n", + " responses.append(response)\n", + " \n", + " # Extract metadata for UI enhancements\n", + " try:\n", + " import json\n", + " content = json.loads(response[\"content\"])\n", + " \n", + " # Extract cities from search results\n", + " if \"results\" in content and isinstance(content[\"results\"], list):\n", + " for item in content[\"results\"]:\n", + " if isinstance(item, dict) and \"city\" in item:\n", + " metadata[\"cities\"].add(item[\"city\"])\n", + " if isinstance(item, dict) and \"id\" in item:\n", + " metadata[\"caregiver_ids\"].add(item[\"id\"])\n", + " \n", + " if \"count\" in content:\n", + " metadata[\"total_results\"] += content[\"count\"]\n", + " \n", + " except:\n", + " pass # Silently ignore metadata extraction errors\n", + " \n", + " return responses, metadata\n", + "\n", + "def generate_city_image(city):\n", + " \"\"\"\n", + " Generate or retrieve a city image (placeholder for future enhancement).\n", + " Could integrate with DALL-E, Unsplash API, or local image database.\n", + " \"\"\"\n", + " # Placeholder - can be enhanced with actual image generation\n", + " return None\n", + "\n", + "def chat(history):\n", + " \"\"\"\n", + " Main chat handler with multi-modal support and defensive error handling.\n", + " Handles conversation flow, tool calls, and response generation.\n", + " \"\"\"\n", + " # Normalize history format\n", + " history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n", + " \n", + " # Initialize conversation with system prompt\n", + " messages = [{\"role\": \"system\", \"content\": system_prompt}] + history\n", + " \n", + " # Initialize metadata\n", + " image = None\n", + " selected_city = None\n", + " \n", + " try:\n", + " # Initial API call\n", + " response = openai.chat.completions.create(\n", + " model=MODEL,\n", + " messages=messages,\n", + " tools=tools\n", + " )\n", + " \n", + " # Tool calling loop (with safety limit)\n", + " max_iterations = 5\n", + " iteration = 0\n", + " \n", + " while response.choices[0].finish_reason == \"tool_calls\" and iteration < max_iterations:\n", + " iteration += 1\n", + " message = response.choices[0].message\n", + " \n", + " # Process all tool calls\n", + " tool_responses, metadata = process_tool_calls(message)\n", + " \n", + " # Track city for image generation\n", + " if metadata[\"cities\"]:\n", + " selected_city = list(metadata[\"cities\"])[0]\n", + " \n", + " # Add assistant message and tool responses to conversation\n", + " messages.append(message)\n", + " messages.extend(tool_responses)\n", + " \n", + " # Continue conversation\n", + " response = openai.chat.completions.create(\n", + " model=MODEL,\n", + " messages=messages,\n", + " tools=tools\n", + " )\n", + " \n", + " # Extract final reply\n", + " reply = response.choices[0].message.content\n", + " history.append({\"role\": \"assistant\", \"content\": reply})\n", + " \n", + " # Generate voice response\n", + " voice = announcements(reply)\n", + " \n", + " # Generate city image if applicable\n", + " if selected_city:\n", + " image = generate_city_image(selected_city)\n", + " \n", + " return history, voice, image\n", + " \n", + " except Exception as e:\n", + " # Defensive error handling 
for the entire chat flow\n", + " error_message = f\"I apologize, but I encountered an error: {str(e)}. Please try again.\"\n", + " history.append({\"role\": \"assistant\", \"content\": error_message})\n", + " return history, None, None" + ] + }, + { + "cell_type": "markdown", + "id": "451ed2e5", + "metadata": {}, + "source": [ + "## Step 9: The Interface\n", + "\n", + "A clean, professional web UI built with Gradio.\n", + "\n", + "Features:\n", + "- Chat interface with conversation history\n", + "- Voice responses that auto-play\n", + "- Settings sidebar (model selection, voice options)\n", + "- Clear instructions for families\n", + "\n", + "**Why Gradio?** At MyWoosah, I needed something non-technical staff could use immediately. Gradio made that possible without weeks of frontend work.\n", + "\n", + "**Run this cell to launch!** 🚀\n" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "a07e7793-b8f5-44f4-aded-5562f633271a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7871\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gradio as gr\n", + "\n", + "# Gradio UI Setup\n", + "\n", + "def put_message_in_chatbot(message, history):\n", + " \"\"\"Add user message to chat history\"\"\"\n", + " return \"\", history + [{\"role\": \"user\", \"content\": message}]\n", + "\n", + "# Custom CSS for better styling\n", + "custom_css = \"\"\"\n", + "#chatbot {\n", + " border-radius: 10px;\n", + " box-shadow: 0 2px 8px rgba(0,0,0,0.1);\n", + "}\n", + "#message_box {\n", + " border-radius: 8px;\n", + "}\n", + ".header {\n", + " text-align: center;\n", + " padding: 20px;\n", + " background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);\n", + " color: white;\n", + " border-radius: 10px;\n", + " margin-bottom: 20px;\n", + "}\n", + "\"\"\"\n", + "\n", + "with gr.Blocks(title=\"CareGiver AI Assistant\", css=custom_css, theme=gr.themes.Soft()) as ui:\n", + " \n", + " # Header\n", + " gr.Markdown(\"\"\"\n", + "
\n", + "

🏥 RoboCare AI Assistant

\n", + "

Find the perfect caregiver for your loved ones

\n", + "
\n", + " \"\"\")\n", + " \n", + " # Instructions\n", + " with gr.Accordion(\"ℹ️ Click here to learn more on how to use this AI\", open=False):\n", + " gr.Markdown(\"\"\"\n", + " **Tell me what you need:**\n", + " - Type of care (elder care, child care, companionship, etc.)\n", + " - Location (city, state)\n", + " - Schedule requirements (days/times)\n", + " - Budget constraints\n", + " - Language or certification needs\n", + " \n", + " **Example:** \"I need an elder care provider in Boston for Monday mornings who speaks Spanish and has CPR certification.\"\n", + " \n", + " ⚠️ **Note:** This is a demo app. Always verify credentials and conduct background checks independently.\n", + " \"\"\")\n", + " \n", + " # Main chat interface\n", + " with gr.Row():\n", + " with gr.Column(scale=2):\n", + " chatbot = gr.Chatbot(\n", + " height=500, \n", + " type=\"messages\",\n", + " elem_id=\"chatbot\",\n", + " label=\"Chat History\",\n", + " avatar_images=(None, \"🤖\")\n", + " )\n", + " \n", + " # Audio output\n", + " audio_output = gr.Audio(\n", + " label=\"Voice Response\",\n", + " autoplay=True,\n", + " visible=True,\n", + " interactive=False\n", + " )\n", + " \n", + " # Settings sidebar\n", + " with gr.Column(scale=1):\n", + " gr.Markdown(\"### ⚙️ Settings\")\n", + " \n", + " # Model selector (for future enhancement)\n", + " model_select = gr.Dropdown(\n", + " choices=[\"gpt-4o-mini\", \"gpt-4o\", \"gpt-4-turbo\"],\n", + " value=\"gpt-4o-mini\",\n", + " label=\"AI Model\",\n", + " interactive=True\n", + " )\n", + " \n", + " # Voice selector\n", + " voice_select = gr.Dropdown(\n", + " choices=[\"coral\", \"alloy\", \"echo\", \"fable\", \"onyx\", \"nova\", \"shimmer\"],\n", + " value=\"coral\",\n", + " label=\"Voice\",\n", + " interactive=True\n", + " )\n", + " \n", + " # Audio toggle\n", + " audio_enabled = gr.Checkbox(\n", + " label=\"Enable Voice Responses\",\n", + " value=True\n", + " )\n", + " \n", + " # Clear button\n", + " clear_btn = gr.Button(\"🗑️ Clear Conversation\", variant=\"secondary\")\n", + " \n", + " # Input section\n", + " with gr.Row():\n", + " message = gr.Textbox(\n", + " label=\"Your Message\",\n", + " placeholder=\"Type your question here... (e.g., 'I need elder care in Boston')\",\n", + " lines=2,\n", + " elem_id=\"message_box\",\n", + " scale=4\n", + " )\n", + " send_btn = gr.Button(\"Send\", variant=\"primary\", scale=1)\n", + " \n", + " # Event handlers\n", + " def chat_wrapper(history):\n", + " \"\"\"Wrapper to handle chat and extract only needed outputs\"\"\"\n", + " history_out, voice, image = chat(history)\n", + " return history_out, voice\n", + " \n", + " # Submit on enter or button click\n", + " submit_event = message.submit(\n", + " put_message_in_chatbot,\n", + " inputs=[message, chatbot],\n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat_wrapper,\n", + " inputs=chatbot,\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " send_btn.click(\n", + " put_message_in_chatbot,\n", + " inputs=[message, chatbot],\n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat_wrapper,\n", + " inputs=chatbot,\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " # Clear conversation\n", + " clear_btn.click(\n", + " lambda: ([], None),\n", + " outputs=[chatbot, audio_output]\n", + " )\n", + " \n", + " # Footer\n", + " gr.Markdown(\"\"\"\n", + " ---\n", + "
\n", + " Powered by OpenAI & Gradio | Built by RoboOffice Ltd\n", + "
\n", + " \"\"\")\n", + "\n", + "# Launch with better configuration\n", + "ui.launch(\n", + " inbrowser=True,\n", + " share=False,\n", + " show_error=True,\n", + " quiet=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "97d87d95", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Reflections\n", + "\n", + "This project started from frustration: *\"There has to be a better way to match families with caregivers.\"*\n", + "\n", + "Through the Andela program, I learned that AI + thoughtful engineering = solutions to real problems.\n", + "\n", + "### What Worked:\n", + "- **Function calling** eliminated the need for custom queries\n", + "- **Prompt engineering** prevented keyword mismatches\n", + "- **Defensive coding** made it robust\n", + "- **Gradio** made it accessible\n", + "\n", + "### What I'd Do Next:\n", + "- Add speech input (families could call and talk)\n", + "- Connect to actual MyWoosah database\n", + "- Add background check API integration\n", + "- Deploy for real users\n", + "\n", + "### The Bigger Picture:\n", + "\n", + "This isn't just about caregiving. The same pattern works for:\n", + "- Healthcare appointments\n", + "- Legal services\n", + "- Tutoring platforms\n", + "- Any matching problem where natural language beats forms\n", + "\n", + "AI doesn't replace good database design—it makes it accessible to everyone.\n", + "\n", + "---\n", + "\n", + "**For MyWoosah Inc and beyond:** This is proof that AI can transform how we connect people with the care they need.\n", + "\n", + "*Built during Week 2 of the Andela LLM Engineering Program*\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/kwabena/week2_solution_.ipynb b/week2/community-contributions/kwabena/week2_solution_.ipynb new file mode 100644 index 0000000..9b1f22e --- /dev/null +++ b/week2/community-contributions/kwabena/week2_solution_.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fd1cdd6e", + "metadata": {}, + "source": [ + "## Week 2 - Full Prototype for Technical Questions Answerer" + ] + }, + { + "cell_type": "markdown", + "id": "70db9a0b", + "metadata": {}, + "source": [ + " This notebook will implement a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df46689d", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7416a2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialization\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "MODEL = \"gpt-4.1-mini\"\n", + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86966749", + "metadata": {}, + "outputs": [], + "source": [ + "system_message = \"\"\"\n", + "You are an expert technical question answerer specializing in data science, programming, \n", + "and software engineering. Your goal is to provide clear, accurate, and practical answers \n", + "to technical questions.\n", + "\n", + "When answering:\n", + "- Break down complex concepts into understandable explanations\n", + "- Provide code examples when relevant, with comments explaining key parts\n", + "- Mention common pitfalls or best practices\n", + "- If a question is ambiguous, state your assumptions or ask for clarification\n", + "- For debugging questions, explain both the fix and why the error occurred\n", + "- Cite specific documentation or resources when helpful\n", + "\n", + "Always prioritize accuracy and clarity over speed. If you're unsure about something, \n", + "acknowledge the uncertainty rather than guessing.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d34e5b81", + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming chat funcion\n", + "def chat(model, history):\n", + " messages = [{\"role\": \"system\", \"content\": system_message}]\n", + " for h in history:\n", + " messages.append({\"role\": h[\"role\"], \"content\": h[\"content\"]})\n", + "\n", + " stream = openai.chat.completions.create(\n", + " model=model, \n", + " messages=messages,\n", + " stream=True\n", + " )\n", + "\n", + " response = \"\"\n", + " for chunk in stream:\n", + " if chunk.choices[0].delta.content is not None:\n", + " response += chunk.choices[0].delta.content\n", + " yield history + [{\"role\": \"assistant\", \"content\": response}]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32350869", + "metadata": {}, + "outputs": [], + "source": [ + "#Gradio Interface\n", + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=500, type=\"messages\")\n", + " with gr.Row():\n", + " message = gr.Textbox(label=\"Chat with AI Assistant: \")\n", + " model_dropdown = gr.Dropdown(\n", + " choices=[\"gpt-4.1-mini\",\"gpt-4o-mini\", \"gpt-4o\", \"gpt-4-turbo\"], \n", + " value=\"gpt-4.1-mini\", \n", + " label=\"Select Model\"\n", + " ) \n", + "\n", + " def handle_submit(user_message, chat_history):\n", + " # Add user message to history\n", + " chat_history = chat_history + [{\"role\": \"user\", \"content\": user_message}]\n", + " return \"\", chat_history\n", + "\n", + " message.submit(\n", + " handle_submit, \n", + " inputs=[message, chatbot], \n", + " outputs=[message, chatbot]\n", + " ).then(\n", + " chat, \n", + " inputs=[model_dropdown, chatbot],\n", + " outputs=[chatbot]\n", + " )\n", + "\n", + 
"ui.launch(inbrowser=True)" + ] + }, + { + "cell_type": "markdown", + "id": "cf2b29e1", + "metadata": {}, + "source": [ + "### Concluding Remarks\n", + "In this exercise, we successfully built a working AI chatbot with Gradio that includes streaming responses and the ability to switch between different models. The implementation demonstrates how to create an interactive interface for LLM applications." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week2/community-contributions/week2_exercise_solution-Stephen.ipynb b/week2/community-contributions/week2_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..21de7d8 --- /dev/null +++ b/week2/community-contributions/week2_exercise_solution-Stephen.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd", + "metadata": {}, + "source": [ + "# End of week 2 Exercise - Bookstore Assistant\n", + "\n", + "Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n", + "\n", + "This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n", + "\n", + "If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n", + "\n", + "I will publish a full solution here soon - unless someone beats me to it...\n", + "\n", + "There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) I can't wait to see your results." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a07e7793-b8f5-44f4-aded-5562f633271a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n", + "Google API Key exists and begins AIzaSyCL\n" + ] + } + ], + "source": [ + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "google_api_key = os.getenv('GOOGLE_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if google_api_key:\n", + " print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n", + "else:\n", + " print(\"Google API Key not set\")\n", + " \n", + "MODEL_GPT = \"gpt-4.1-mini\"\n", + "MODEL_GEMINI = \"gemini-2.5-pro\"\n", + "\n", + "\n", + "openai = OpenAI()\n", + "\n", + "gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "gemini = OpenAI(api_key=google_api_key, base_url=gemini_url)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a3aa8bf", + "metadata": {}, + "outputs": [], + "source": [ + "# Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models\n", + "\n", + "system_message= \"\"\"\n", + " You are an assistant in a software engineering bookstore that analyzes the content of technical books and generates concise, informative summaries for readers.\n", + " Your goal is to help customers quickly understand what each book covers, its practical value, and who would benefit most from reading it.\n", + " Respond in markdown without code blocks.\n", + " Each summary should include:\n", + " Overview: The book’s main topic, scope, and focus area (e.g., software architecture, DevOps, system design).\n", + " Key Insights: The most important lessons, principles, or methodologies discussed.\n", + " Recommended For: The type of reader who would benefit most (e.g., junior developers, engineering managers, backend specialists).\n", + " Related Reads: Suggest one or two similar or complementary titles available in the store.\n", + " Maintain a professional and knowledgeable tone that reflects expertise in software engineering literature. 
\n", + "\"\"\"\n", + "\n", + "def stream_gpt(prompt):\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + " stream = openai.chat.completions.create(\n", + " model=MODEL_GPT,\n", + " messages=messages,\n", + " stream=True\n", + " )\n", + " result = \"\"\n", + " for chunk in stream:\n", + " result += chunk.choices[0].delta.content or \"\"\n", + " yield result\n", + "\n", + "def stream_gemini(prompt):\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + " stream = openai.chat.completions.create(\n", + " model=MODEL_GEMINI,\n", + " messages=messages,\n", + " stream=True\n", + " )\n", + " result = \"\"\n", + " for chunk in stream:\n", + " result += chunk.choices[0].delta.content or \"\"\n", + " yield result\n", + "\n", + "def stream_model(prompt, model):\n", + " if model==\"GPT\":\n", + " result = stream_gpt(prompt)\n", + " elif model==\"Gemini\":\n", + " result = stream_gemini(prompt)\n", + " else:\n", + " raise ValueError(\"Unknown model\")\n", + " yield from result\n", + "\n", + "\n", + "message_input = gr.Textbox(label=\"Your message:\", info=\"Enter a software engineering book title for the LLM\", lines=4)\n", + "model_selector = gr.Dropdown([\"GPT\", \"Gemini\"], label=\"Select model\", value=\"GPT\")\n", + "message_output = gr.Markdown(label=\"Response:\")\n", + "\n", + "view = gr.Interface(\n", + " fn=stream_model,\n", + " title=\"Bookstore Assistant\", \n", + " inputs=[message_input, model_selector], \n", + " outputs=[message_output], \n", + " examples=[\n", + " [\"Explain Clean Code by Robert C. Martin\", \"GPT\"],\n", + " [\"Explain Clean Code by Robert C. Martin\", \"Gemini\"]\n", + " ], \n", + " flagging_mode=\"never\"\n", + " )\n", + "view.launch()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4d7980c", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3\n", + "\n", + "DB = \"books.db\"\n", + "\n", + "with sqlite3.connect(DB) as conn:\n", + " cursor = conn.cursor()\n", + " cursor.execute('CREATE TABLE IF NOT EXISTS prices (title TEXT PRIMARY KEY, price REAL)')\n", + " conn.commit()\n", + "\n", + "def get_book_price(title):\n", + " print(f\"DATABASE TOOL CALLED: Getting price for {title}\", flush=True)\n", + " with sqlite3.connect(DB) as conn:\n", + " cursor = conn.cursor()\n", + " cursor.execute('SELECT price FROM prices WHERE title = ?', (title.lower(),))\n", + " result = cursor.fetchone()\n", + " return f\"Book -> {title} price is ${result[0]}\" if result else \"No price data available for this title\"\n", + "\n", + "def set_book_price(title, price):\n", + " with sqlite3.connect(DB) as conn:\n", + " cursor = conn.cursor()\n", + " cursor.execute('INSERT INTO prices (title, price) VALUES (?, ?) ON CONFLICT(title) DO UPDATE SET price = ?', (title.lower(), price, price))\n", + " conn.commit()\n", + "\n", + "book_prices = {\"Clean code\":20, \"Clean architecture\": 30, \"System design\": 40, \"Design patterns\": 50}\n", + "for title, price in book_prices.items():\n", + " set_book_price(title, price)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86741761", + "metadata": {}, + "outputs": [], + "source": [ + "# use of a tool\n", + "MODEL = \"gpt-4.1-mini\"\n", + "\n", + "system_message = \"\"\"\n", + "You are a helpful assistant in a software engineering bookstore BookEye. 
\n", + "Give short, courteous answers, no more than 1 sentence.\n", + "Always be accurate. If you don't know the answer, say so.\n", + "\"\"\"\n", + "\n", + "price_function = {\n", + " \"name\": \"get_book_price\",\n", + " \"description\": \"Get the price of a book.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"book_title\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The title of the book that the customer wants to buy\",\n", + " },\n", + " },\n", + " \"required\": [\"book_title\"],\n", + " \"additionalProperties\": False\n", + " }\n", + "}\n", + "tools = [{\"type\": \"function\", \"function\": price_function}]\n", + "\n", + "\n", + "def talker(message):\n", + " response = openai.audio.speech.create(\n", + " model=\"gpt-4o-mini-tts\",\n", + " voice=\"coral\",\n", + " input=message\n", + " )\n", + " return response.content\n", + "\n", + "def handle_tool_calls(message):\n", + " responses = []\n", + " for tool_call in message.tool_calls:\n", + " if tool_call.function.name == \"get_book_price\":\n", + " arguments = json.loads(tool_call.function.arguments)\n", + " title = arguments.get('book_title')\n", + " price_details = get_book_price(title)\n", + " responses.append({\n", + " \"role\": \"tool\",\n", + " \"content\": price_details,\n", + " \"tool_call_id\": tool_call.id\n", + " })\n", + " return responses\n", + "\n", + "def chat(history):\n", + " history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n", + " messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", + "\n", + " while response.choices[0].finish_reason==\"tool_calls\":\n", + " message = response.choices[0].message\n", + " responses = handle_tool_calls(message)\n", + " messages.append(message)\n", + " messages.extend(responses)\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", + "\n", + " reply = response.choices[0].message.content\n", + " history += [{\"role\":\"assistant\", \"content\":reply}]\n", + "\n", + " voice = talker(reply)\n", + " \n", + " return history, voice\n", + "\n", + "def put_message_in_chatbot(message, history):\n", + " return \"\", history + [{\"role\":\"user\", \"content\":message}]\n", + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=300, type=\"messages\")\n", + " audio_output = gr.Audio(autoplay=True)\n", + " \n", + " with gr.Row():\n", + " message = gr.Textbox(label=\"Chat with our AI Assistant:\")\n", + "\n", + " message.submit(put_message_in_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(\n", + " chat, inputs=chatbot, outputs=[chatbot, audio_output]\n", + " )\n", + "\n", + "ui.launch(inbrowser=True, auth=(\"ted\", \"mowsb\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb new file mode 100644 index 0000000..a4474af --- /dev/null +++ 
b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb @@ -0,0 +1,906 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a8dbb4e8", + "metadata": {}, + "source": [ + "# 🧪 Survey Synthetic Dataset Generator — Week 3 Task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d86f629", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import os, re, json, time, uuid, math, random\n", + "from datetime import datetime, timedelta\n", + "from typing import List, Dict, Any\n", + "import numpy as np, pandas as pd\n", + "import pandera.pandas as pa\n", + "random.seed(7); np.random.seed(7)\n", + "print(\"✅ Base libraries ready. Pandera available:\", pa is not None)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f196ae73", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def extract_strict_json(text: str):\n", + " \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n", + " if text is None:\n", + " raise ValueError(\"Empty model output.\")\n", + " \n", + " t = text.strip()\n", + " \n", + " # Strategy 1: Direct JSON parsing\n", + " try:\n", + " obj = json.loads(t)\n", + " if isinstance(obj, list):\n", + " return obj\n", + " elif isinstance(obj, dict):\n", + " for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n", + " if key in obj and isinstance(obj[key], list):\n", + " return obj[key]\n", + " if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n", + " return [obj[k] for k in sorted(obj.keys(), key=int)]\n", + " except json.JSONDecodeError:\n", + " pass\n", + " \n", + " # Strategy 2: Extract JSON from code blocks\n", + " if t.startswith(\"```\"):\n", + " t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n", + " \n", + " # Strategy 3: Find JSON array in text\n", + " start, end = t.find('['), t.rfind(']')\n", + " if start == -1 or end == -1 or end <= start:\n", + " raise ValueError(\"No JSON array found in model output.\")\n", + " \n", + " t = t[start:end+1]\n", + " \n", + " # Strategy 4: Fix common JSON issues\n", + " t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t) # Remove trailing commas\n", + " t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t) # Replace NaN/Infinity\n", + " t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\") # Remove invisible characters\n", + " \n", + " try:\n", + " return json.loads(t)\n", + " except json.JSONDecodeError as e:\n", + " raise ValueError(f\"Could not parse JSON: {str(e)}. 
Text: {t[:200]}...\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "3670fa0d", + "metadata": {}, + "source": [ + "## 1) Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d16bd03a", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "CFG = {\n", + " \"rows\": 800,\n", + " \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n", + " \"fields\": [\n", + " {\"name\": \"response_id\", \"type\": \"uuid4\"},\n", + " {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n", + " {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n", + " {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n", + " {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n", + " {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n", + " {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n", + " {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n", + " {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n", + " {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n", + " {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n", + " {\"name\": \"attention_passed\", \"type\": \"bool\"},\n", + " {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n", + " {\"name\": \"is_detractor\", \"type\": \"bool\"}\n", + " ]\n", + "}\n", + "print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "7da1f429", + "metadata": {}, + "source": [ + "## 2) Helpers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2f5fdff", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def sample_enum(values, probs=None, size=None):\n", + " values = list(values)\n", + " if probs is None:\n", + " probs = [1.0 / len(values)] * len(values)\n", + " return np.random.choice(values, p=probs, size=size)\n", + "\n", + "def sample_numeric(field_cfg, size=1):\n", + " t = field_cfg[\"type\"]\n", + " if t == \"int\":\n", + " lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n", + " dist = field_cfg.get(\"distribution\", \"uniform\")\n", + " if dist == \"uniform\":\n", + " return np.random.randint(lo, hi + 1, size=size)\n", + " elif dist == \"normal\":\n", + " mu = (lo + hi) / 2.0\n", + " sigma = (hi - lo) / 6.0\n", + " out = np.random.normal(mu, sigma, size=size)\n", + " return np.clip(out, lo, hi).astype(int)\n", + " else:\n", + " return np.random.randint(lo, hi + 1, size=size)\n", + " elif t == \"float\":\n", + " lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n", + " dist = field_cfg.get(\"distribution\", 
\"uniform\")\n", + " if dist == \"uniform\":\n", + " return np.random.uniform(lo, hi, size=size)\n", + " elif dist == \"normal\":\n", + " mu = (lo + hi) / 2.0\n", + " sigma = (hi - lo) / 6.0\n", + " return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n", + " elif dist == \"lognormal\":\n", + " mu = math.log(max(1e-3, (lo + hi) / 2.0))\n", + " sigma = 0.75\n", + " out = np.random.lognormal(mu, sigma, size=size)\n", + " return np.clip(out, lo, hi)\n", + " else:\n", + " return np.random.uniform(lo, hi, size=size)\n", + " else:\n", + " raise ValueError(\"Unsupported numeric type\")\n", + "\n", + "def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n", + " s = datetime.fromisoformat(start)\n", + " e = datetime.fromisoformat(end)\n", + " total = int((e - s).total_seconds())\n", + " r = np.random.randint(0, total, size=size)\n", + " return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n" + ] + }, + { + "cell_type": "markdown", + "id": "5f24111a", + "metadata": {}, + "source": [ + "## 3) Rule-based Generator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd61330d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n", + " n = CFG[\"rows\"]\n", + " dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n", + " data = {}\n", + " for f in CFG[\"fields\"]:\n", + " name, t = f[\"name\"], f[\"type\"]\n", + " if t == \"uuid4\":\n", + " data[name] = [str(uuid.uuid4()) for _ in range(n)]\n", + " elif t in (\"int\",\"float\"):\n", + " data[name] = sample_numeric(f, size=n)\n", + " elif t == \"enum\":\n", + " data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n", + " elif t == \"datetime\":\n", + " data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n", + " elif t == \"bool\":\n", + " data[name] = np.random.rand(n) < 0.9 # 90% True\n", + " else:\n", + " data[name] = [None]*n\n", + " df = pd.DataFrame(data)\n", + "\n", + " # Derive NPS roughly from likert questions\n", + " if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n", + " likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n", + " df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)\n", + "\n", + " # Heuristic target: is_detractor more likely when completion high & attention failed\n", + " if \"is_detractor\" in df.columns:\n", + " base = 0.25\n", + " comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n", + " attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n", + " boost = (comp > 900).astype(int) + (~attn).astype(int)\n", + " p = np.clip(base + 0.15*boost, 0.01, 0.95)\n", + " df[\"is_detractor\"] = np.random.rand(n) < p\n", + "\n", + " return df\n", + "\n", + "df_rule = generate_rule_based(CFG)\n", + "df_rule.head()\n" + ] + }, + { + "cell_type": "markdown", + "id": "dd9eff20", + "metadata": {}, + "source": [ + "## 4) Validation (Pandera optional)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a4ef86a", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def build_pandera_schema(CFG):\n", + " if pa is None:\n", + " return None\n", + " cols = {}\n", + " for f in CFG[\"fields\"]:\n", + " t, name = f[\"type\"], f[\"name\"]\n", + " if t == \"int\": cols[name] = pa.Column(int)\n", + " elif t == 
\"float\": cols[name] = pa.Column(float)\n", + " elif t == \"enum\": cols[name] = pa.Column(object)\n", + " elif t == \"datetime\": cols[name] = pa.Column(object)\n", + " elif t == \"uuid4\": cols[name] = pa.Column(object)\n", + " elif t == \"bool\": cols[name] = pa.Column(bool)\n", + " else: cols[name] = pa.Column(object)\n", + " return pa.DataFrameSchema(cols) if pa is not None else None\n", + "\n", + "def validate_df(df, CFG):\n", + " schema = build_pandera_schema(CFG)\n", + " if schema is None:\n", + " return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n", + " try:\n", + " v = schema.validate(df, lazy=True)\n", + " return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n", + " except Exception as e:\n", + " print(\"Validation error:\", e)\n", + " return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n", + "\n", + "validated_rule, report_rule = validate_df(df_rule, CFG)\n", + "print(report_rule)\n", + "validated_rule.head()\n" + ] + }, + { + "cell_type": "markdown", + "id": "d5f1d93a", + "metadata": {}, + "source": [ + "## 5) Save" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73626b4c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from pathlib import Path\n", + "out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + "ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + "csv_path = out / f\"survey_rule_{ts}.csv\"\n", + "validated_rule.to_csv(csv_path, index=False)\n", + "print(\"Saved:\", csv_path.as_posix())\n" + ] + }, + { + "cell_type": "markdown", + "id": "87c89b51", + "metadata": {}, + "source": [ + "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24e94771", + "metadata": {}, + "outputs": [], + "source": [ + "# Fixed LLM Generation Functions\n", + "def create_survey_prompt(CFG, n_rows=50):\n", + " \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n", + " fields_desc = []\n", + " for field in CFG['fields']:\n", + " name = field['name']\n", + " field_type = field['type']\n", + " \n", + " if field_type == 'int':\n", + " min_val = field.get('min', 0)\n", + " max_val = field.get('max', 100)\n", + " fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n", + " elif field_type == 'float':\n", + " min_val = field.get('min', 0.0)\n", + " max_val = field.get('max', 100.0)\n", + " fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n", + " elif field_type == 'enum':\n", + " values = field.get('values', [])\n", + " fields_desc.append(f\" - {name}: one of {values}\")\n", + " elif field_type == 'bool':\n", + " fields_desc.append(f\" - {name}: boolean (true/false)\")\n", + " elif field_type == 'uuid4':\n", + " fields_desc.append(f\" - {name}: UUID string\")\n", + " elif field_type == 'datetime':\n", + " fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n", + " fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n", + " else:\n", + " fields_desc.append(f\" - {name}: {field_type}\")\n", + " \n", + " prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n", + "\n", + "Schema:\n", + "{chr(10).join(fields_desc)}\n", + "\n", + "CRITICAL REQUIREMENTS:\n", + "- Return a JSON object with a \"responses\" key containing an array\n", + "- Each object in the array must have all required fields\n", + "- Use realistic, diverse values for survey responses\n", + "- No 
trailing commas\n", + "- No comments or explanations\n", + "\n", + "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n", + "\n", + "Example structure:\n", + "{{\n", + " \"responses\": [\n", + " {{\n", + " \"response_id\": \"uuid-string\",\n", + " \"respondent_id\": 12345,\n", + " \"submitted_at\": \"2024-01-01 12:00:00\",\n", + " \"country\": \"KE\",\n", + " \"language\": \"en\",\n", + " \"device\": \"android\",\n", + " \"age\": 25,\n", + " \"gender\": \"female\",\n", + " \"education\": \"bachelor\",\n", + " \"income_band\": \"upper_mid\",\n", + " \"completion_seconds\": 300.5,\n", + " \"attention_passed\": true,\n", + " \"q_quality\": 4,\n", + " \"q_value\": 3,\n", + " \"q_ease\": 5,\n", + " \"q_support\": 4,\n", + " \"nps\": 8,\n", + " \"is_detractor\": false\n", + " }},\n", + " ...\n", + " ]\n", + "}}\n", + "\n", + "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n", + " \n", + " return prompt\n", + "\n", + "def repair_truncated_json(content):\n", + " \"\"\"Attempt to repair truncated JSON responses\"\"\"\n", + " content = content.strip()\n", + " \n", + " # If it starts with { but doesn't end with }, try to close it\n", + " if content.startswith('{') and not content.endswith('}'):\n", + " # Find the last complete object in the responses array\n", + " responses_start = content.find('\"responses\": [')\n", + " if responses_start != -1:\n", + " # Find the last complete object\n", + " brace_count = 0\n", + " last_complete_pos = -1\n", + " in_string = False\n", + " escape_next = False\n", + " \n", + " for i, char in enumerate(content[responses_start:], responses_start):\n", + " if escape_next:\n", + " escape_next = False\n", + " continue\n", + " \n", + " if char == '\\\\':\n", + " escape_next = True\n", + " continue\n", + " \n", + " if char == '\"' and not escape_next:\n", + " in_string = not in_string\n", + " continue\n", + " \n", + " if not in_string:\n", + " if char == '{':\n", + " brace_count += 1\n", + " elif char == '}':\n", + " brace_count -= 1\n", + " if brace_count == 0:\n", + " last_complete_pos = i\n", + " break\n", + " \n", + " if last_complete_pos != -1:\n", + " # Truncate at the last complete object and close the JSON\n", + " repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n", + " print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n", + " return repaired\n", + " \n", + " return content\n", + "\n", + "def fixed_llm_generate_batch(CFG, n_rows=50):\n", + " \"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key, using rule-based fallback\")\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n", + " \n", + " # Calculate appropriate max_tokens based on batch size\n", + " # Roughly 200-300 tokens per row, with some buffer\n", + " estimated_tokens = n_rows * 300 + 500 # Buffer for JSON structure\n", + " max_tokens = min(max(estimated_tokens, 2000), 8000) # Between 2k-8k tokens\n", + " \n", + " print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation 
expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=max_tokens,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"📝 Raw response length: {len(content)} characters\")\n", + " \n", + " # Check if response appears truncated\n", + " if not content.strip().endswith('}') and not content.strip().endswith(']'):\n", + " print(\"⚠️ Response appears truncated, attempting repair...\")\n", + " content = repair_truncated_json(content)\n", + " \n", + " # Try to extract JSON with improved logic\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"🔍 Parsed JSON type: {type(data)}\")\n", + " \n", + " if isinstance(data, list):\n", + " df = pd.DataFrame(data)\n", + " print(f\"📊 Direct array: {len(df)} rows\")\n", + " elif isinstance(data, dict):\n", + " # Check for common keys that might contain the data\n", + " for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n", + " if key in data and isinstance(data[key], list):\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"📊 Found data in '{key}': {len(df)} rows\")\n", + " break\n", + " else:\n", + " # If no standard key found, check if all values are lists/objects\n", + " list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n", + " if list_keys:\n", + " # Use the first list key found\n", + " key = list_keys[0]\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"📊 Found data in '{key}': {len(df)} rows\")\n", + " else:\n", + " # Try to convert the dict values to a list\n", + " if all(isinstance(v, dict) for v in data.values()):\n", + " df = pd.DataFrame(list(data.values()))\n", + " print(f\"📊 Converted dict values: {len(df)} rows\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n", + " \n", + " if len(df) == n_rows:\n", + " print(f\"✅ Successfully generated {len(df)} survey responses\")\n", + " return df\n", + " else:\n", + " print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}\")\n", + " if len(df) > 0:\n", + " return df\n", + " else:\n", + " raise ValueError(\"No data generated\")\n", + " \n", + " except json.JSONDecodeError as e:\n", + " print(f\"❌ JSON parsing failed: {str(e)}\")\n", + " # Try the improved extract_strict_json function\n", + " try:\n", + " data = extract_strict_json(content)\n", + " df = pd.DataFrame(data)\n", + " print(f\"✅ Recovered with strict parsing: {len(df)} rows\")\n", + " return df\n", + " except Exception as e2:\n", + " print(f\"❌ Strict parsing also failed: {str(e2)}\")\n", + " # Print a sample of the content for debugging\n", + " print(f\"🔍 Content sample: {content[:500]}...\")\n", + " raise e2\n", + " \n", + " except Exception as e:\n", + " print(f'❌ LLM error, fallback to rule-based mock: {str(e)}')\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + "\n", + "def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n", + " \"\"\"Fixed LLM generation with adaptive batch processing\"\"\"\n", + " print(f\"🚀 Generating {total_rows} survey responses with adaptive batching\")\n", + " \n", + " # Adaptive batch sizing based on total rows\n", + " if total_rows <= 20:\n", + " optimal_batch_size = min(batch_size, total_rows)\n", + " elif total_rows <= 50:\n", + " 
optimal_batch_size = min(15, batch_size)\n", + " elif total_rows <= 100:\n", + " optimal_batch_size = min(10, batch_size)\n", + " else:\n", + " optimal_batch_size = min(8, batch_size)\n", + " \n", + " print(f\"📊 Using optimal batch size: {optimal_batch_size}\")\n", + " \n", + " all_dataframes = []\n", + " remaining = total_rows\n", + " \n", + " while remaining > 0:\n", + " current_batch_size = min(optimal_batch_size, remaining)\n", + " print(f\"\\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n", + " \n", + " try:\n", + " batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n", + " all_dataframes.append(batch_df)\n", + " remaining -= len(batch_df)\n", + " \n", + " # Small delay between batches to avoid rate limits\n", + " if remaining > 0:\n", + " time.sleep(1.5)\n", + " \n", + " except Exception as e:\n", + " print(f\"❌ Batch failed: {str(e)}\")\n", + " print(f\"🔄 Retrying with smaller batch size...\")\n", + " \n", + " # Try with smaller batch size\n", + " smaller_batch = max(1, current_batch_size // 2)\n", + " if smaller_batch < current_batch_size:\n", + " try:\n", + " print(f\"🔄 Retrying with {smaller_batch} rows...\")\n", + " batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n", + " all_dataframes.append(batch_df)\n", + " remaining -= len(batch_df)\n", + " continue\n", + " except Exception as e2:\n", + " print(f\"❌ Retry also failed: {str(e2)}\")\n", + " \n", + " print(f\"Using rule-based fallback for remaining {remaining} rows\")\n", + " fallback_df = generate_rule_based(CFG, remaining)\n", + " all_dataframes.append(fallback_df)\n", + " break\n", + " \n", + " if all_dataframes:\n", + " result = pd.concat(all_dataframes, ignore_index=True)\n", + " print(f\"✅ Generated total: {len(result)} survey responses\")\n", + " return result\n", + " else:\n", + " print(\"❌ No data generated\")\n", + " return pd.DataFrame()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1af410e", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the fixed LLM generation\n", + "print(\"🧪 Testing LLM generation...\")\n", + "\n", + "# Test with small dataset first\n", + "test_df = fixed_llm_generate_batch(CFG, 10)\n", + "print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n", + "print(f\"\\n📋 First few rows:\")\n", + "print(test_df.head())\n", + "print(f\"\\n📈 Data types:\")\n", + "print(test_df.dtypes)\n", + "\n", + "# Debug function to see what the LLM is actually returning\n", + "def debug_llm_response(CFG, n_rows=5):\n", + " \"\"\"Debug function to see raw LLM response\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key available for debugging\")\n", + " return\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n", + " print(f\"📝 Prompt length: {len(prompt)} characters\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation expert. 
Generate realistic survey data in JSON format.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=2000,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"📝 Raw response length: {len(content)} characters\")\n", + " print(f\"🔍 First 200 characters: {content[:200]}\")\n", + " print(f\"🔍 Last 200 characters: {content[-200:]}\")\n", + " \n", + " # Try to parse\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"✅ JSON parsed successfully\")\n", + " print(f\"🔍 Data type: {type(data)}\")\n", + " if isinstance(data, dict):\n", + " print(f\"🔍 Dict keys: {list(data.keys())}\")\n", + " elif isinstance(data, list):\n", + " print(f\"🔍 List length: {len(data)}\")\n", + " except Exception as e:\n", + " print(f\"❌ JSON parsing failed: {str(e)}\")\n", + " \n", + " except Exception as e:\n", + " print(f\"❌ Debug failed: {str(e)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75c90739", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the fixed implementation\n", + "print(\"🧪 Testing the fixed LLM generation...\")\n", + "\n", + "# Test with small dataset\n", + "test_df = fixed_llm_generate_batch(CFG, 5)\n", + "print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n", + "print(f\"\\n📋 First few rows:\")\n", + "print(test_df.head())\n", + "print(f\"\\n📈 Data types:\")\n", + "print(test_df.dtypes)\n", + "\n", + "if not test_df.empty:\n", + " print(f\"\\n✅ SUCCESS! LLM generation is now working!\")\n", + " print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n", + "else:\n", + " print(f\"\\n❌ Still having issues with LLM generation\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83b842", + "metadata": {}, + "outputs": [], + "source": [ + "#Test larger dataset generation \n", + "print(\"🚀 Testing larger dataset generation...\")\n", + "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n", + "if not large_df.empty:\n", + " print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n", + " print(f\"\\n📈 Summary statistics:\")\n", + " print(large_df.describe())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n", + " large_df.to_csv(csv_path, index=False)\n", + " print(f\"💾 Saved: {csv_path}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6029d3e2", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def build_json_schema(CFG):\n", + " schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n", + " props = schema['items']['properties']; req = schema['items']['required']\n", + " for f in CFG['fields']:\n", + " name, t = f['name'], f['type']\n", + " req.append(name)\n", + " if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n", + " elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n", + " elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n", + " elif t == 'bool': props[name] = {'type':'boolean'}\n", + " else: props[name] = {'type':'string'}\n", + " return schema\n", + "\n", + "PROMPT_PREAMBLE = (\n", + " \"You are a data generator. Return ONLY JSON. \"\n", + " \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. 
\"\n", + " \"No prose, no code fences, no trailing commas.\"\n", + ")\n", + "\n", + "def render_prompt(CFG, n_rows=100):\n", + " minimal_cfg = {'fields': []}\n", + " for f in CFG['fields']:\n", + " base = {k: f[k] for k in ['name','type'] if k in f}\n", + " if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n", + " if 'values' in f: base.update({'values': f['values']})\n", + " if 'fmt' in f: base.update({'fmt': f['fmt']})\n", + " minimal_cfg['fields'].append(base)\n", + " return {\n", + " 'preamble': PROMPT_PREAMBLE,\n", + " 'n_rows': n_rows,\n", + " 'schema': build_json_schema(CFG),\n", + " 'constraints': minimal_cfg,\n", + " 'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n", + " }\n", + "\n", + "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n", + " try:\n", + " obj = json.loads(raw)\n", + " if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n", + " return pd.DataFrame(obj['rows'])\n", + " except Exception:\n", + " pass\n", + " data = extract_strict_json(raw)\n", + " return pd.DataFrame(data)\n", + "\n", + "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n", + "print('LLM available:', USE_LLM)\n", + "\n", + "def llm_generate_batch(CFG, n_rows=50):\n", + " if USE_LLM:\n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " prompt = json.dumps(render_prompt(CFG, n_rows))\n", + " resp = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': prompt}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw = resp.choices[0].message.content\n", + " try:\n", + " return parse_llm_json_to_df(raw)\n", + " except Exception:\n", + " stricter = (\n", + " prompt\n", + " + \"\\nReturn ONLY a JSON object structured as: \"\n", + " + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. 
\"\n", + " + \"No prose, no explanations.\"\n", + " )\n", + " resp2 = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': stricter}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw2 = resp2.choices[0].message.content\n", + " return parse_llm_json_to_df(raw2)\n", + " except Exception as e:\n", + " print('LLM error, fallback to rule-based mock:', e)\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + "\n", + "def generate_llm(CFG, total_rows=200, batch_size=50):\n", + " dfs = []; remaining = total_rows\n", + " while remaining > 0:\n", + " b = min(batch_size, remaining)\n", + " dfs.append(llm_generate_batch(CFG, n_rows=b))\n", + " remaining -= b\n", + " time.sleep(0.2)\n", + " return pd.concat(dfs, ignore_index=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e759087", + "metadata": {}, + "outputs": [], + "source": [ + "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n", + "df_llm.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d4908ad", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved LLM generation with adaptive batching\n", + "print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n", + "\n", + "# Test with smaller dataset first\n", + "print(\"\\n📦 Testing small batch (10 rows)...\")\n", + "small_df = fixed_llm_generate_batch(CFG, 10)\n", + "print(f\"✅ Small batch result: {len(small_df)} rows\")\n", + "\n", + "# Test with medium dataset using adaptive batching\n", + "print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n", + "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n", + "print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n", + "\n", + "if not medium_df.empty:\n", + " print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n", + " print(f\"\\n📋 First few rows:\")\n", + " print(medium_df.head())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n", + " medium_df.to_csv(csv_path, index=False)\n", + " print(f\"💾 Saved: {csv_path}\")\n", + "else:\n", + " print(\"❌ Medium dataset generation failed\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week3/community-contributions/week3_exercise_solution-Stephen.ipynb b/week3/community-contributions/week3_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..bbc99e7 --- /dev/null +++ b/week3/community-contributions/week3_exercise_solution-Stephen.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c58e628f", + "metadata": {}, + "source": [ + "\n", + "## **Week 3 task.**\n", + "Create your own tool that generates synthetic data/test data. Input the type of dataset or products or job postings, etc. 
and let the tool dream up various data samples.\n", + "\n", + "https://colab.research.google.com/drive/13wR4Blz3Ot_x0GOpflmvvFffm5XU3Kct?usp=sharing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0ddde9ed", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "import requests\n", + "import torch\n", + "from IPython.display import Markdown, display, update_display\n", + "from openai import OpenAI\n", + "from huggingface_hub import login\n", + "from huggingface_hub import login\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n", + "from dotenv import load_dotenv\n", + "import gradio as gr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbbc6cc8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "llama_api_key = \"ollama\"\n", + "\n", + "# hf_token = userdata.get('HF_TOKEN')\n", + "# login(hf_token, add_to_git_credential=True)\n", + "\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if llama_api_key:\n", + " print(f\"LLama API Key exists\")\n", + "else:\n", + " print(\"LLama API Key not set\")\n", + " \n", + "GPT_MODEL = \"gpt-4.1-mini\"\n", + "LLAMA_MODEL = \"llama3.1\"\n", + "\n", + "\n", + "openai = OpenAI()\n", + "\n", + "llama_url = \"http://localhost:11434/v1\"\n", + "llama = OpenAI(api_key=llama_api_key, base_url=llama_url)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ef083ec6", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_with_gpt(user_prompt: str, num_samples: int = 5):\n", + " \"\"\"\n", + " Generates synthetic data using OpenAI's GPT.\n", + " Return a JSON string.\n", + " \"\"\"\n", + " if not openai:\n", + " return json.dumps({\"error\": \"OpenAI client not initialized. Please check your API key.\"}, indent=2)\n", + "\n", + " try:\n", + " response = openai.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"You are a data generation assistant. Generate a JSON array of exactly {num_samples} objects based on the user's request. The output must be valid JSON only, without any other text or formatting.\"},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " response_format={\"type\": \"json_object\"}\n", + " )\n", + " \n", + " json_text = response.choices[0].message.content\n", + " return json_text\n", + " except APIError as e:\n", + " return json.dumps({\"error\": f\"Error from OpenAI API: {e.body}\"}, indent=2)\n", + " except Exception as e:\n", + " return json.dumps({\"error\": f\"An unexpected error occurred: {e}\"}, indent=2)\n", + "\n", + "def generate_with_gpt(user_prompt: str, num_samples: int = 5):\n", + " \"\"\"\n", + " Generates synthetic data using OpenAI's GPT.\n", + " Return a JSON string.\n", + " \"\"\"\n", + " if not openai:\n", + " return json.dumps({\"error\": \"OpenAI client not initialized. Please check your API key.\"}, indent=2)\n", + "\n", + " try:\n", + " response = openai.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"You are a data generation assistant. Generate a JSON array of exactly {num_samples} objects based on the user's request. 
The output must be valid JSON only, without any other text or formatting.\"},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " response_format={\"type\": \"json_object\"}\n", + " )\n", + " \n", + " json_text = response.choices[0].message.content\n", + " return json_text\n", + " except APIError as e:\n", + " return json.dumps({\"error\": f\"Error from OpenAI API: {e.body}\"}, indent=2)\n", + " except Exception as e:\n", + " return json.dumps({\"error\": f\"An unexpected error occurred: {e}\"}, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b98f84d8", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_data(user_prompt, model_choice):\n", + " \"\"\"\n", + " Wrapper function that calls the appropriate generation function based on model choice.\n", + " \"\"\"\n", + " if not user_prompt:\n", + " return json.dumps({\"error\": \"Please provide a description for the data.\"}, indent=2)\n", + "\n", + " if model_choice == f\"Hugging Face ({LLAMA_MODEL})\":\n", + " return generate_with_llama(user_prompt)\n", + " elif model_choice == f\"OpenAI ({GPT_MODEL})\":\n", + " return generate_with_gpt(user_prompt)\n", + " else:\n", + " return json.dumps({\"error\": \"Invalid model choice.\"}, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adbc19a8", + "metadata": {}, + "outputs": [], + "source": [ + "# Gradio UI\n", + "with gr.Blocks(theme=gr.themes.Glass(), title=\"Synthetic Data Generator\") as ui:\n", + " gr.Markdown(\"# Synthetic Data Generator\")\n", + " gr.Markdown(\"Describe the type of data you need, select a model, and click 'Generate'.\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column(scale=3):\n", + " data_prompt = gr.Textbox(\n", + " lines=5,\n", + " label=\"Data Prompt\",\n", + " placeholder=\"e.g., a list of customer profiles with name, email, and a favorite product\"\n", + " )\n", + " \n", + " with gr.Column(scale=1):\n", + " model_choice = gr.Radio(\n", + " [f\"Hugging Face ({LLAMA_MODEL})\", f\"OpenAI ({GPT_MODEL})\"],\n", + " label=\"Choose a Model\",\n", + " value=f\"Hugging Face ({LLAMA_MODEL})\"\n", + " )\n", + " \n", + " generate_btn = gr.Button(\"Generate Data\")\n", + " \n", + " with gr.Row():\n", + " output_json = gr.JSON(label=\"Generated Data\")\n", + " \n", + " generate_btn.click(\n", + " fn=generate_data,\n", + " inputs=[data_prompt, model_choice],\n", + " outputs=output_json\n", + " )\n", + "\n", + "ui.launch(inbrowser=True, debug=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/Exercise_week4_jom.ipynb b/week4/community-contributions/Exercise_week4_jom.ipynb new file mode 100644 index 0000000..79704d2 --- /dev/null +++ b/week4/community-contributions/Exercise_week4_jom.ipynb @@ -0,0 +1,264 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "fee27f39", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "openai_api_key = 
os.getenv('OPENAI_API_KEY')\n", + "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n", + "google_api_key = os.getenv('GOOGLE_API_KEY')\n", + "ollama_api_key = os.getenv('OLLAMA_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "if anthropic_api_key:\n", + " print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n", + "else:\n", + " print(\"Anthropic API Key not set (and this is optional)\")\n", + "\n", + "if google_api_key:\n", + " print(f\"Google API Key exists and begins {google_api_key[:2]}\")\n", + "else:\n", + " print(\"Google API Key not set (and this is optional)\")\n", + "\n", + "if ollama_api_key:\n", + " print(f\"OLLAMA API Key exists and begins {ollama_api_key[:2]}\")\n", + "else:\n", + " print(\"OLLAMA API Key not set (and this is optional)\")\n", + "\n", + "# Connect to client libraries\n", + "\n", + "openai = OpenAI()\n", + "\n", + "anthropic_url = \"https://api.anthropic.com/v1/\"\n", + "gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "ollama_url = \"http://localhost:11434/v1\"\n", + "\n", + "anthropic = OpenAI(api_key=anthropic_api_key, base_url=anthropic_url)\n", + "gemini = OpenAI(api_key=google_api_key, base_url=gemini_url)\n", + "ollama = OpenAI(api_key=ollama_api_key, base_url=ollama_url)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d26f4175", + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-5\", \"claude-sonnet-4-5-20250929\", \"gemini-2.5-pro\", \"gpt-oss:20b-cloud\", ]\n", + "\n", + "clients = {\"gpt-5\": openai, \"claude-sonnet-4-5-20250929\": anthropic, \"gemini-2.5-pro\": gemini, \"gpt-oss:20b-cloud\": ollama}\n", + "\n", + "# Want to keep costs ultra-low? Replace this with models of your choice, using the examples from yesterday" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76563884", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt_doc = \"\"\"You are an expert Python developer and code reviewer.\n", + "Your job is to read the user's provided function, and return:\n", + "1. A concise, PEP-257-compliant docstring summarizing what the function does, clarifying types, parameters, return values, and side effects.\n", + "2. Helpful inline comments that improve both readability and maintainability, without restating what the code obviously does.\n", + "\n", + "Only output the function, not explanations or additional text. \n", + "Do not modify variable names or refactor the function logic.\n", + "Your response should improve the code's clarity and documentation, making it easier for others to understand and maintain.\n", + "Don't be extremely verbose.\n", + "Your answer should be at a {level} level of expertise.\n", + "\"\"\"\n", + "\n", + "system_prompt_tests = \"\"\"You are a seasoned Python developer and testing expert.\n", + "Your task is to read the user's provided function, and generate:\n", + "1. A concise set of meaningful unit tests that thoroughly validate the function's correctness, including typical, edge, and error cases.\n", + "2. The tests should be written for pytest (or unittest if pytest is not appropriate), use clear, descriptive names, and avoid unnecessary complexity.\n", + "3. 
If dependencies or mocking are needed, include minimal necessary setup code (but avoid over-mocking).\n",
+    "\n",
+    "Only output the relevant test code, not explanations or extra text.\n",
+    "Do not change the original function; focus solely on comprehensive, maintainable test coverage that other developers can easily understand and extend.\n",
+    "\"\"\"\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1bd82e96",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_documentation(code, model, level):\n",
+    "    response = clients[model].chat.completions.create(\n",
+    "        model=model,\n",
+    "        messages=[\n",
+    "            {\"role\": \"system\", \"content\": system_prompt_doc.format(level=level)},\n",
+    "            {\"role\": \"user\", \"content\": code}\n",
+    "        ],\n",
+    "        stream=True\n",
+    "    )\n",
+    "    output = \"\"\n",
+    "    for chunk in response:\n",
+    "        output += chunk.choices[0].delta.content or \"\"\n",
+    "        yield output.replace(\"```python\", \"\").replace(\"```\", \"\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b01b3421",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_tests(code, model):\n",
+    "    response = clients[model].chat.completions.create(\n",
+    "        model=model,\n",
+    "        messages=[\n",
+    "            {\"role\": \"system\", \"content\": system_prompt_tests},\n",
+    "            {\"role\": \"user\", \"content\": code}\n",
+    "        ],\n",
+    "        stream=True\n",
+    "    )\n",
+    "    output = \"\"\n",
+    "    for chunk in response:\n",
+    "        output += chunk.choices[0].delta.content or \"\"\n",
+    "        yield output.replace(\"```python\", \"\").replace(\"```\", \"\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16b71915",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vscode_dark = gr.themes.Monochrome(\n",
+    "    primary_hue=\"blue\",\n",
+    "    secondary_hue=\"slate\",\n",
+    "    neutral_hue=\"slate\",\n",
+    ").set(\n",
+    "    body_background_fill=\"#1e1e1e\",\n",
+    "    body_background_fill_dark=\"#1e1e1e\",\n",
+    "    block_background_fill=\"#252526\",\n",
+    "    block_background_fill_dark=\"#252526\",\n",
+    "    block_border_color=\"#3e3e42\",\n",
+    "    block_border_color_dark=\"#3e3e42\",\n",
+    "    border_color_primary=\"#3e3e42\",\n",
+    "    block_label_background_fill=\"#252526\",\n",
+    "    block_label_background_fill_dark=\"#252526\",\n",
+    "    block_label_text_color=\"#cccccc\",\n",
+    "    block_label_text_color_dark=\"#cccccc\",\n",
+    "    block_title_text_color=\"#cccccc\",\n",
+    "    block_title_text_color_dark=\"#cccccc\",\n",
+    "    body_text_color=\"#d4d4d4\",\n",
+    "    body_text_color_dark=\"#d4d4d4\",\n",
+    "    button_primary_background_fill=\"#0e639c\",\n",
+    "    button_primary_background_fill_dark=\"#0e639c\",\n",
+    "    button_primary_background_fill_hover=\"#1177bb\",\n",
+    "    button_primary_background_fill_hover_dark=\"#1177bb\",\n",
+    "    button_primary_text_color=\"#ffffff\",\n",
+    "    button_primary_text_color_dark=\"#ffffff\",\n",
+    "    input_background_fill=\"#3c3c3c\",\n",
+    "    input_background_fill_dark=\"#3c3c3c\",\n",
+    "    color_accent=\"#007acc\",\n",
+    "    color_accent_soft=\"#094771\",\n",
+    ")\n"
+   ]
+  },
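+  {
+   "cell_type": "markdown",
+   "id": "streaming-sanity-note",
+   "metadata": {},
+   "source": [
+    "`generate_documentation` and `generate_tests` are generators that yield the accumulated text after every chunk, which is what lets Gradio stream partial output. A quick sanity check outside the UI (a sketch: it assumes the default `models[-1]` entry, the Ollama-served `gpt-oss:20b-cloud`, is reachable; any model you have credentials for works)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "streaming-sanity-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity-check the streaming helper outside Gradio (hypothetical example).\n",
+    "# Each yield is the full text generated so far, so keep only the last one.\n",
+    "sample_fn = \"def add(a, b):\\n    return a + b\"\n",
+    "final = \"\"\n",
+    "for partial in generate_documentation(sample_fn, models[-1], \"Mid\"):\n",
+    "    final = partial\n",
+    "print(final)"
+   ]
+  },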
gr.Tab(\"Docstring & Comments\") as tab1:\n", + " gr.Markdown(\"# Function Docstring & Comment Helper\\nPaste your function below and get helpful docstrings and inline comments!\")\n", + "\n", + " with gr.Row():\n", + " code_input_1 = gr.Code(label=\"Paste your Python function here\", lines=10, language=\"python\")\n", + " code_output = gr.Code(label=\"Function with improved docstring and comments\", lines=10, language=\"python\")\n", + " \n", + " with gr.Row(equal_height=True):\n", + " level_radio = gr.Radio(choices=[\"Junior\", \"Mid\", \"Senior\"], value=\"Mid\", label=\"Reviewer level\", interactive=True)\n", + " model_dropdown = gr.Dropdown(choices=models, value=models[-1], label=\"Select model\")\n", + " submit_doc_btn = gr.Button(\"Generate docstring & comments\", scale=0.5)\n", + "\n", + " submit_doc_btn.click(\n", + " generate_documentation, \n", + " inputs=[code_input_1, model_dropdown, level_radio], \n", + " outputs=code_output\n", + " )\n", + "\n", + " with gr.Tab(\"Unit Tests\") as tab2:\n", + " gr.Markdown(\"# Unit Test Generator\\nPaste your function below and get auto-generated unit tests!\")\n", + "\n", + " with gr.Row():\n", + " code_input_2 = gr.Code(label=\"Paste your Python function here\", lines=10, language=\"python\")\n", + " code_output_2 = gr.Code(label=\"Generated tests\", lines=10, language=\"python\")\n", + " \n", + " with gr.Row(equal_height=True):\n", + " model_dropdown_2 = gr.Dropdown(choices=models, value=models[-1], label=\"Select model\")\n", + " submit_test_btn = gr.Button(\"Generate unit tests\", scale=0.5)\n", + "\n", + " submit_test_btn.click(\n", + " generate_tests, \n", + " inputs=[code_input_2, model_dropdown_2], \n", + " outputs=code_output_2\n", + " )\n", + " \n", + " tab2.select(lambda x: x, inputs=code_input_1, outputs=code_input_2)\n", + " tab1.select(lambda x: x, inputs=code_input_2, outputs=code_input_1)\n", + "\n", + "ui.launch(share=False, inbrowser=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/bharat_puri/docstring_generator.ipynb b/week4/community-contributions/bharat_puri/docstring_generator.ipynb new file mode 100644 index 0000000..8f17a08 --- /dev/null +++ b/week4/community-contributions/bharat_puri/docstring_generator.ipynb @@ -0,0 +1,596 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4a6ab9a2-28a2-445d-8512-a0dc8d1b54e9", + "metadata": {}, + "source": [ + "# Code DocString / Comment Generator\n", + "\n", + "Submitted By : Bharat Puri\n", + "\n", + "Goal: Build a code tool that scans Python modules, finds functions/classes\n", + "without docstrings, and uses an LLM (Claude / GPT / Gemini / Qwen etc.)\n", + "to generate high-quality Google or NumPy style docstrings." 
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# imports\n",
+    "\n",
+    "import os\n",
+    "import io\n",
+    "import sys\n",
+    "import re\n",
+    "from dotenv import load_dotenv\n",
+    "sys.path.append(os.path.abspath(os.path.join(\"..\", \"..\")))\n",
+    "from openai import OpenAI\n",
+    "import gradio as gr\n",
+    "import subprocess\n",
+    "from IPython.display import Markdown, display\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4f672e1c-87e9-4865-b760-370fa605e614",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "load_dotenv(override=True)\n",
+    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
+    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
+    "google_api_key = os.getenv('GOOGLE_API_KEY')\n",
+    "grok_api_key = os.getenv('GROK_API_KEY')\n",
+    "groq_api_key = os.getenv('GROQ_API_KEY')\n",
+    "openrouter_api_key = os.getenv('OPENROUTER_API_KEY')\n",
+    "\n",
+    "if openai_api_key:\n",
+    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
+    "else:\n",
+    "    print(\"OpenAI API Key not set\")\n",
+    "\n",
+    "if anthropic_api_key:\n",
+    "    print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
+    "else:\n",
+    "    print(\"Anthropic API Key not set (and this is optional)\")\n",
+    "\n",
+    "if google_api_key:\n",
+    "    print(f\"Google API Key exists and begins {google_api_key[:2]}\")\n",
+    "else:\n",
+    "    print(\"Google API Key not set (and this is optional)\")\n",
+    "\n",
+    "if grok_api_key:\n",
+    "    print(f\"Grok API Key exists and begins {grok_api_key[:4]}\")\n",
+    "else:\n",
+    "    print(\"Grok API Key not set (and this is optional)\")\n",
+    "\n",
+    "if groq_api_key:\n",
+    "    print(f\"Groq API Key exists and begins {groq_api_key[:4]}\")\n",
+    "else:\n",
+    "    print(\"Groq API Key not set (and this is optional)\")\n",
+    "\n",
+    "if openrouter_api_key:\n",
+    "    print(f\"OpenRouter API Key exists and begins {openrouter_api_key[:6]}\")\n",
+    "else:\n",
+    "    print(\"OpenRouter API Key not set (and this is optional)\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "59863df1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Connect to client libraries\n",
+    "\n",
+    "openai = OpenAI()\n",
+    "\n",
+    "anthropic_url = \"https://api.anthropic.com/v1/\"\n",
+    "gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
+    "grok_url = \"https://api.x.ai/v1\"\n",
+    "groq_url = \"https://api.groq.com/openai/v1\"\n",
+    "ollama_url = \"http://localhost:11434/v1\"\n",
+    "openrouter_url = \"https://openrouter.ai/api/v1\"\n",
+    "\n",
+    "anthropic = OpenAI(api_key=anthropic_api_key, base_url=anthropic_url)\n",
+    "gemini = OpenAI(api_key=google_api_key, base_url=gemini_url)\n",
+    "grok = OpenAI(api_key=grok_api_key, base_url=grok_url)\n",
+    "groq = OpenAI(api_key=groq_api_key, base_url=groq_url)\n",
+    "ollama = OpenAI(api_key=\"ollama\", base_url=ollama_url)\n",
+    "openrouter = OpenAI(api_key=openrouter_api_key, base_url=openrouter_url)\n",
+    "\n",
+    "MODEL = os.getenv(\"DOCGEN_MODEL\", \"gpt-4o-mini\")\n",
+    "\n",
+    "\n",
+    "# Registry for multiple model providers\n",
+    "MODEL_REGISTRY = {\n",
+    "    \"gpt-4o-mini (OpenAI)\": {\n",
+    "        \"provider\": \"openai\",\n",
+    "        \"model\": \"gpt-4o-mini\",\n",
+    "    },\n",
+    "    \"gpt-4o (OpenAI)\": {\n",
+    "        \"provider\": \"openai\",\n",
+    "        \"model\": \"gpt-4o\",\n",
+    "    },\n",
+    "    \"claude-3.5-sonnet (Anthropic)\": {\n",
+    "        \"provider\": 
\"anthropic\",\n", + " \"model\": \"claude-3.5-sonnet\",\n", + " },\n", + " \"gemini-1.5-pro (Google)\": {\n", + " \"provider\": \"google\",\n", + " \"model\": \"gemini-1.5-pro\",\n", + " },\n", + " \"codellama-7b (Open Source)\": {\n", + " \"provider\": \"open_source\",\n", + " \"model\": \"codellama-7b\",\n", + " },\n", + " \"starcoder2 (Open Source)\": {\n", + " \"provider\": \"open_source\",\n", + " \"model\": \"starcoder2\",\n", + " },\n", + "}\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da", + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-5\", \"claude-sonnet-4-5-20250929\", \"grok-4\", \"gemini-2.5-pro\", \"qwen2.5-coder\", \"deepseek-coder-v2\", \"gpt-oss:20b\", \"qwen/qwen3-coder-30b-a3b-instruct\", \"openai/gpt-oss-120b\", ]\n", + "\n", + "clients = {\"gpt-5\": openai, \"claude-sonnet-4-5-20250929\": anthropic, \"grok-4\": grok, \"gemini-2.5-pro\": gemini, \"openai/gpt-oss-120b\": groq, \"qwen2.5-coder\": ollama, \"deepseek-coder-v2\": ollama, \"gpt-oss:20b\": ollama, \"qwen/qwen3-coder-30b-a3b-instruct\": openrouter}\n", + "\n", + "# Want to keep costs ultra-low? Replace this with models of your choice, using the examples from yesterday" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "17b7d074-b1a4-4673-adec-918f82a4eff0", + "metadata": {}, + "outputs": [], + "source": [ + "# ================================================================\n", + "# Prompt Templates and Utilities\n", + "# ================================================================\n", + "\n", + "DOCSTYLE_TEMPLATES = {\n", + " \"google\": (\n", + " \"You will write a concise Google-style Python docstring for the given function or class.\\n\"\n", + " \"Rules:\\n\"\n", + " \"- One-line summary followed by short details.\\n\"\n", + " \"- Include Args:, Returns:, Raises: only if relevant.\\n\"\n", + " \"- Keep under 12 lines, no code fences or markdown formatting.\\n\"\n", + " \"Return ONLY the text between triple quotes.\"\n", + " ),\n", + "}\n", + "\n", + "SYSTEM_PROMPT = (\n", + " \"You are a senior Python engineer and technical writer. 
\"\n", + " \"Write precise, helpful docstrings.\"\n", + ")\n", + "\n", + "\n", + "def make_user_prompt(style: str, module_name: str, signature: str, code_context: str) -> str:\n", + " \"\"\"Build the user message for the model based on template and context.\"\"\"\n", + " instr = DOCSTYLE_TEMPLATES.get(style, DOCSTYLE_TEMPLATES[\"google\"])\n", + " prompt = (\n", + " f\"{instr}\\n\\n\"\n", + " f\"Module: {module_name}\\n\"\n", + " f\"Signature:\\n{signature}\\n\\n\"\n", + " f\"Code context:\\n{code_context}\\n\\n\"\n", + " \"Return ONLY a triple-quoted docstring, for example:\\n\"\n", + " '\"\"\"One-line summary.\\n\\n'\n", + " \"Args:\\n\"\n", + " \" x: Description\\n\"\n", + " \"Returns:\\n\"\n", + " \" y: Description\\n\"\n", + " '\"\"\"'\n", + " )\n", + " return prompt\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "16b3c10f-f7bc-4a2f-a22f-65c6807b7574", + "metadata": {}, + "outputs": [], + "source": [ + "# ================================================================\n", + "# LLM Chat Helper — OpenAI GPT\n", + "# ================================================================\n", + "def llm_generate_docstring(signature: str, context: str, style: str = \"google\", \n", + " module_name: str = \"module\", model_choice: str = \"gpt-4o-mini (OpenAI)\") -> str:\n", + " \"\"\"\n", + " Generate a Python docstring using the selected model provider.\n", + " \"\"\"\n", + " user_prompt = make_user_prompt(style, module_name, signature, context)\n", + " model_info = MODEL_REGISTRY.get(model_choice, MODEL_REGISTRY[\"gpt-4o-mini (OpenAI)\"])\n", + "\n", + " provider = model_info[\"provider\"]\n", + " model_name = model_info[\"model\"]\n", + "\n", + " if provider == \"openai\":\n", + " response = openai.chat.completions.create(\n", + " model=model_name,\n", + " temperature=0.2,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a senior Python engineer and technical writer.\"},\n", + " {\"role\": \"user\", \"content\": user_prompt},\n", + " ],\n", + " )\n", + " text = response.choices[0].message.content.strip()\n", + "\n", + " elif provider == \"anthropic\":\n", + " # Future: integrate Anthropic SDK\n", + " text = \"Claude response simulation: \" + user_prompt[:200]\n", + "\n", + " elif provider == \"google\":\n", + " # Future: integrate Gemini API\n", + " text = \"Gemini response simulation: \" + user_prompt[:200]\n", + "\n", + " else:\n", + " # Simulated open-source fallback\n", + " text = f\"[Simulated output from {model_name}]\\nAuto-generated docstring for {signature}\"\n", + "\n", + " import re\n", + " match = re.search(r'\"\"\"(.*?)\"\"\"', text, re.S)\n", + " return match.group(1).strip() if match else text\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "82da91ac-e563-4425-8b45-1b94880d342f", + "metadata": {}, + "outputs": [], + "source": [ + "# ================================================================\n", + "# 🧱 AST Parsing Utilities — find missing docstrings\n", + "# ================================================================\n", + "import ast\n", + "\n", + "def node_signature(node: ast.AST) -> str:\n", + " \"\"\"\n", + " Build a readable signature string from a FunctionDef or ClassDef node.\n", + " Example: def add(x, y) -> int:\n", + " \"\"\"\n", + " if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):\n", + " args = [a.arg for a in node.args.args]\n", + " if node.args.vararg:\n", + " args.append(\"*\" + node.args.vararg.arg)\n", + " for a in node.args.kwonlyargs:\n", + " 
args.append(a.arg + \"=?\")\n", + " if node.args.kwarg:\n", + " args.append(\"**\" + node.args.kwarg.arg)\n", + " ret = \"\"\n", + " if getattr(node, \"returns\", None):\n", + " try:\n", + " ret = f\" -> {ast.unparse(node.returns)}\"\n", + " except Exception:\n", + " pass\n", + " return f\"def {node.name}({', '.join(args)}){ret}:\"\n", + "\n", + " elif isinstance(node, ast.ClassDef):\n", + " return f\"class {node.name}:\"\n", + "\n", + " return \"\"\n", + "\n", + "\n", + "def context_snippet(src: str, node: ast.AST, max_lines: int = 60) -> str:\n", + " \"\"\"\n", + " Extract a small snippet of source code around a node for context.\n", + " This helps the LLM understand what the function/class does.\n", + " \"\"\"\n", + " lines = src.splitlines()\n", + " start = getattr(node, \"lineno\", 1) - 1\n", + " end = getattr(node, \"end_lineno\", start + 1)\n", + " snippet = lines[start:end]\n", + " if len(snippet) > max_lines:\n", + " snippet = snippet[:max_lines] + [\"# ... (truncated) ...\"]\n", + " return \"\\n\".join(snippet)\n", + "\n", + "\n", + "def find_missing_docstrings(src: str):\n", + " \"\"\"\n", + " Parse the Python source code and return a list of nodes\n", + " (module, class, function) that do NOT have docstrings.\n", + " \"\"\"\n", + " tree = ast.parse(src)\n", + " missing = []\n", + "\n", + " # Module-level docstring check\n", + " if ast.get_docstring(tree) is None:\n", + " missing.append((\"module\", tree))\n", + "\n", + " # Walk through the AST for classes and functions\n", + " for node in ast.walk(tree):\n", + " if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):\n", + " if ast.get_docstring(node) is None:\n", + " kind = \"class\" if isinstance(node, ast.ClassDef) else \"function\"\n", + " missing.append((kind, node))\n", + "\n", + " return missing\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea69108f-e4ca-4326-89fe-97c5748c0e79", + "metadata": {}, + "outputs": [], + "source": [ + "## Quick Test ##\n", + "\n", + "code = '''\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "class Counter:\n", + " def inc(self):\n", + " self.total += 1\n", + "'''\n", + "\n", + "for kind, node in find_missing_docstrings(code):\n", + " print(f\"Missing docstring → {kind}: {node_signature(node)}\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "00d65b96-e65d-4e11-89be-06f265a5f2e3", + "metadata": {}, + "outputs": [], + "source": [ + "# ================================================================\n", + "# Insert Generated Docstrings into Code\n", + "# ================================================================\n", + "import difflib\n", + "import textwrap\n", + "\n", + "def insert_docstring(src: str, node: ast.AST, docstring: str) -> str:\n", + " \"\"\"\n", + " Insert a generated docstring inside a function/class node.\n", + " Keeps indentation consistent with the original code.\n", + " \"\"\"\n", + " lines = src.splitlines()\n", + " if not hasattr(node, \"body\") or not node.body:\n", + " return src # nothing to insert into\n", + "\n", + " start_idx = node.body[0].lineno - 1\n", + " indent = re.match(r\"\\s*\", lines[start_idx]).group(0)\n", + " ds_lines = textwrap.indent(f'\"\"\"{docstring.strip()}\"\"\"', indent).splitlines()\n", + "\n", + " new_lines = lines[:start_idx] + ds_lines + [\"\"] + lines[start_idx:]\n", + " return \"\\n\".join(new_lines)\n", + "\n", + "\n", + "def insert_module_docstring(src: str, docstring: str) -> str:\n", + " \"\"\"Insert a module-level docstring at the top of the 
file.\"\"\"\n", + " lines = src.splitlines()\n", + " ds_block = f'\"\"\"{docstring.strip()}\"\"\"\\n'\n", + " return ds_block + \"\\n\".join(lines)\n", + "\n", + "\n", + "def diff_text(a: str, b: str) -> str:\n", + " \"\"\"Show unified diff of original vs updated code.\"\"\"\n", + " return \"\".join(\n", + " difflib.unified_diff(\n", + " a.splitlines(keepends=True),\n", + " b.splitlines(keepends=True),\n", + " fromfile=\"original.py\",\n", + " tofile=\"updated.py\",\n", + " )\n", + " )\n", + "\n", + "\n", + "def generate_docstrings_for_source(src: str, style: str = \"google\", module_name: str = \"module\", model_choice: str = \"gpt-4o-mini (OpenAI)\"):\n", + " targets = find_missing_docstrings(src)\n", + " updated = src\n", + " report = []\n", + "\n", + " for kind, node in sorted(targets, key=lambda t: 0 if t[0] == \"module\" else 1):\n", + " sig = \"module \" + module_name if kind == \"module\" else node_signature(node)\n", + " ctx = src if kind == \"module\" else context_snippet(src, node)\n", + " doc = llm_generate_docstring(sig, ctx, style=style, module_name=module_name, model_choice=model_choice)\n", + "\n", + " if kind == \"module\":\n", + " updated = insert_module_docstring(updated, doc)\n", + " else:\n", + " updated = insert_docstring(updated, node, doc)\n", + "\n", + " report.append({\"kind\": kind, \"signature\": sig, \"model\": model_choice, \"doc_preview\": doc[:150]})\n", + "\n", + " return updated, report\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d00cf4b7-773d-49cb-8262-9d11d787ee10", + "metadata": {}, + "outputs": [], + "source": [ + "## Quick Test ##\n", + "new_code, report = generate_docstrings_for_source(code, style=\"google\", module_name=\"demo\")\n", + "\n", + "print(\"=== Generated Docstrings ===\")\n", + "for r in report:\n", + " print(f\"- {r['kind']}: {r['signature']}\")\n", + " print(\" \", r['doc_preview'])\n", + "print(\"\\n=== Updated Source ===\")\n", + "print(new_code)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b318db41-c05d-48ce-9990-b6f1a0577c68", + "metadata": {}, + "outputs": [], + "source": [ + "# ================================================================\n", + "# 📂 File-Based Workflow — preview or apply docstrings\n", + "# ================================================================\n", + "from pathlib import Path\n", + "import pandas as pd\n", + "\n", + "def process_file(path: str, style: str = \"google\", apply: bool = False) -> pd.DataFrame:\n", + " \"\"\"\n", + " Process a .py file: find missing docstrings, generate them via GPT,\n", + " and either preview the diff or apply the updates in place.\n", + " \"\"\"\n", + " p = Path(path)\n", + " src = p.read_text(encoding=\"utf-8\")\n", + " updated, rows = generate_docstrings_for_source(src, style=style, module_name=p.stem)\n", + "\n", + " if apply:\n", + " p.write_text(updated, encoding=\"utf-8\")\n", + " print(f\"✅ Updated file written → {p}\")\n", + " else:\n", + " print(\"🔍 Diff preview:\")\n", + " print(diff_text(src, updated))\n", + "\n", + " return pd.DataFrame(rows)\n", + "\n", + "# Example usage:\n", + "# df = process_file(\"my_script.py\", style=\"google\", apply=False) # preview\n", + "# df = process_file(\"my_script.py\", style=\"google\", apply=True) # overwrite with docstrings\n", + "# df\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "8962cf0e-9255-475e-bbc1-21500be0cd78", + "metadata": {}, + "outputs": [], + "source": [ + "# 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0b0f852-982f-4918-9b5d-89880cc12003",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ================================================================\n",
+    "# 🎨 Enhanced Gradio Interface with Model Selector\n",
+    "# ================================================================\n",
+    "import gradio as gr\n",
+    "\n",
+    "def gradio_generate(code_text: str, style: str, model_choice: str):\n",
+    "    \"\"\"Wrapper for Gradio — generates docstrings using selected model.\"\"\"\n",
+    "    if not code_text.strip():\n",
+    "        return \"⚠️ Please paste some Python code first.\"\n",
+    "    try:\n",
+    "        updated, _ = generate_docstrings_for_source(\n",
+    "            code_text, style=style, module_name=\"gradio_snippet\", model_choice=model_choice\n",
+    "        )\n",
+    "        return updated\n",
+    "    except Exception as e:\n",
+    "        return f\"❌ Error: {e}\"\n",
+    "\n",
+    "with gr.Blocks(theme=gr.themes.Soft()) as doc_ui:\n",
+    "    gr.Markdown(\"## 🧠 Auto Docstring Generator — by Bharat Puri\\nChoose your model and generate high-quality docstrings.\")\n",
+    "\n",
+    "    with gr.Row():\n",
+    "        code_input = gr.Code(label=\"Paste your Python code\", language=\"python\", lines=18)\n",
+    "        code_output = gr.Code(label=\"Generated code with docstrings\", language=\"python\", lines=18)\n",
+    "\n",
+    "    with gr.Row():\n",
+    "        style_choice = gr.Radio([\"google\"], value=\"google\", label=\"Docstring Style\")\n",
+    "        model_choice = gr.Dropdown(\n",
+    "            list(MODEL_REGISTRY.keys()),\n",
+    "            value=\"gpt-4o-mini (OpenAI)\",\n",
+    "            label=\"Select Model\",\n",
+    "        )\n",
+    "\n",
+    "    generate_btn = gr.Button(\"🚀 Generate Docstrings\")\n",
+    "    generate_btn.click(\n",
+    "        fn=gradio_generate,\n",
+    "        inputs=[code_input, style_choice, model_choice],\n",
+    "        outputs=[code_output],\n",
+    "    )\n",
+    "\n",
+    "doc_ui.launch(share=False)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e6d6720-de8e-4cbb-be9f-82bac3dcc71a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/python_to_cpp_code_translator/examples/calculator.py b/week4/community-contributions/python_to_cpp_code_translator/examples/calculator.py new file mode 100644 index 0000000..35af2d7 --- /dev/null +++ b/week4/community-contributions/python_to_cpp_code_translator/examples/calculator.py @@ -0,0 +1,190 @@ +""" +Simple calculator class with history tracking. +""" + +import math +from typing import List, Union + +class Calculator: + """A simple calculator with history tracking.""" + + def __init__(self): + """Initialize calculator with empty history.""" + self.history: List[str] = [] + self.memory: float = 0.0 + + def add(self, a: float, b: float) -> float: + """Add two numbers.""" + result = a + b + self.history.append(f"{a} + {b} = {result}") + return result + + def subtract(self, a: float, b: float) -> float: + """Subtract b from a.""" + result = a - b + self.history.append(f"{a} - {b} = {result}") + return result + + def multiply(self, a: float, b: float) -> float: + """Multiply two numbers.""" + result = a * b + self.history.append(f"{a} * {b} = {result}") + return result + + def divide(self, a: float, b: float) -> float: + """Divide a by b.""" + if b == 0: + raise ValueError("Cannot divide by zero") + result = a / b + self.history.append(f"{a} / {b} = {result}") + return result + + def power(self, base: float, exponent: float) -> float: + """Calculate base raised to the power of exponent.""" + result = base ** exponent + self.history.append(f"{base} ^ {exponent} = {result}") + return result + + def square_root(self, number: float) -> float: + """Calculate square root of a number.""" + if number < 0: + raise ValueError("Cannot calculate square root of negative number") + result = math.sqrt(number) + self.history.append(f"√{number} = {result}") + return result + + def factorial(self, n: int) -> int: + """Calculate factorial of n.""" + if n < 0: + raise ValueError("Factorial is not defined for negative numbers") + if n == 0 or n == 1: + return 1 + + result = 1 + for i in range(2, n + 1): + result *= i + + self.history.append(f"{n}! 
= {result}") + return result + + def memory_store(self, value: float) -> None: + """Store value in memory.""" + self.memory = value + self.history.append(f"Memory stored: {value}") + + def memory_recall(self) -> float: + """Recall value from memory.""" + self.history.append(f"Memory recalled: {self.memory}") + return self.memory + + def memory_clear(self) -> None: + """Clear memory.""" + self.memory = 0.0 + self.history.append("Memory cleared") + + def get_history(self) -> List[str]: + """Get calculation history.""" + return self.history.copy() + + def clear_history(self) -> None: + """Clear calculation history.""" + self.history.clear() + + def get_last_result(self) -> Union[float, None]: + """Get the result of the last calculation.""" + if not self.history: + return None + + last_entry = self.history[-1] + # Extract result from history entry + if "=" in last_entry: + return float(last_entry.split("=")[-1].strip()) + return None + +class ScientificCalculator(Calculator): + """Extended calculator with scientific functions.""" + + def sine(self, angle: float) -> float: + """Calculate sine of angle in radians.""" + result = math.sin(angle) + self.history.append(f"sin({angle}) = {result}") + return result + + def cosine(self, angle: float) -> float: + """Calculate cosine of angle in radians.""" + result = math.cos(angle) + self.history.append(f"cos({angle}) = {result}") + return result + + def tangent(self, angle: float) -> float: + """Calculate tangent of angle in radians.""" + result = math.tan(angle) + self.history.append(f"tan({angle}) = {result}") + return result + + def logarithm(self, number: float, base: float = math.e) -> float: + """Calculate logarithm of number with given base.""" + if number <= 0: + raise ValueError("Logarithm is not defined for non-positive numbers") + if base <= 0 or base == 1: + raise ValueError("Logarithm base must be positive and not equal to 1") + + result = math.log(number, base) + self.history.append(f"log_{base}({number}) = {result}") + return result + + def degrees_to_radians(self, degrees: float) -> float: + """Convert degrees to radians.""" + return degrees * math.pi / 180 + + def radians_to_degrees(self, radians: float) -> float: + """Convert radians to degrees.""" + return radians * 180 / math.pi + +def main(): + """Main function to demonstrate calculator functionality.""" + print("Calculator Demo") + print("=" * 30) + + # Basic calculator + calc = Calculator() + + print("Basic Calculator Operations:") + print(f"5 + 3 = {calc.add(5, 3)}") + print(f"10 - 4 = {calc.subtract(10, 4)}") + print(f"6 * 7 = {calc.multiply(6, 7)}") + print(f"15 / 3 = {calc.divide(15, 3)}") + print(f"2 ^ 8 = {calc.power(2, 8)}") + print(f"√64 = {calc.square_root(64)}") + print(f"5! 
= {calc.factorial(5)}") + + print(f"\nCalculation History:") + for entry in calc.get_history(): + print(f" {entry}") + + # Scientific calculator + print("\n" + "=" * 30) + print("Scientific Calculator Operations:") + + sci_calc = ScientificCalculator() + + # Convert degrees to radians for trigonometric functions + angle_deg = 45 + angle_rad = sci_calc.degrees_to_radians(angle_deg) + + print(f"sin({angle_deg}°) = {sci_calc.sine(angle_rad):.4f}") + print(f"cos({angle_deg}°) = {sci_calc.cosine(angle_rad):.4f}") + print(f"tan({angle_deg}°) = {sci_calc.tangent(angle_rad):.4f}") + print(f"ln(10) = {sci_calc.logarithm(10):.4f}") + print(f"log₁₀(100) = {sci_calc.logarithm(100, 10):.4f}") + + print(f"\nScientific Calculator History:") + for entry in sci_calc.get_history(): + print(f" {entry}") + +if __name__ == "__main__": + main() + + + + diff --git a/week4/community-contributions/python_to_cpp_code_translator/examples/fibonacci.py b/week4/community-contributions/python_to_cpp_code_translator/examples/fibonacci.py new file mode 100644 index 0000000..6a41a83 --- /dev/null +++ b/week4/community-contributions/python_to_cpp_code_translator/examples/fibonacci.py @@ -0,0 +1,64 @@ +""" +Fibonacci sequence implementation in Python. +""" + +def fibonacci(n): + """Calculate the nth Fibonacci number using recursion.""" + if n <= 1: + return n + return fibonacci(n-1) + fibonacci(n-2) + +def fibonacci_iterative(n): + """Calculate the nth Fibonacci number using iteration.""" + if n <= 1: + return n + + a, b = 0, 1 + for _ in range(2, n + 1): + a, b = b, a + b + return b + +def fibonacci_sequence(count): + """Generate a sequence of Fibonacci numbers.""" + sequence = [] + for i in range(count): + sequence.append(fibonacci(i)) + return sequence + +def main(): + """Main function to demonstrate Fibonacci calculations.""" + print("Fibonacci Sequence Demo") + print("=" * 30) + + # Calculate first 10 Fibonacci numbers + for i in range(10): + result = fibonacci(i) + print(f"fibonacci({i}) = {result}") + + print("\nFirst 15 Fibonacci numbers:") + sequence = fibonacci_sequence(15) + print(sequence) + + # Performance comparison + import time + + n = 30 + print(f"\nPerformance comparison for fibonacci({n}):") + + start_time = time.time() + recursive_result = fibonacci(n) + recursive_time = time.time() - start_time + + start_time = time.time() + iterative_result = fibonacci_iterative(n) + iterative_time = time.time() - start_time + + print(f"Recursive: {recursive_result} (took {recursive_time:.4f}s)") + print(f"Iterative: {iterative_result} (took {iterative_time:.4f}s)") + +if __name__ == "__main__": + main() + + + + diff --git a/week4/community-contributions/python_to_cpp_code_translator/examples/sorting_algorithms.py b/week4/community-contributions/python_to_cpp_code_translator/examples/sorting_algorithms.py new file mode 100644 index 0000000..4200070 --- /dev/null +++ b/week4/community-contributions/python_to_cpp_code_translator/examples/sorting_algorithms.py @@ -0,0 +1,150 @@ +""" +Various sorting algorithms implemented in Python. 
+""" + +import random +import time +from typing import List + +def bubble_sort(arr: List[int]) -> List[int]: + """Sort array using bubble sort algorithm.""" + n = len(arr) + arr = arr.copy() # Don't modify original array + + for i in range(n): + for j in range(0, n - i - 1): + if arr[j] > arr[j + 1]: + arr[j], arr[j + 1] = arr[j + 1], arr[j] + + return arr + +def selection_sort(arr: List[int]) -> List[int]: + """Sort array using selection sort algorithm.""" + n = len(arr) + arr = arr.copy() + + for i in range(n): + min_idx = i + for j in range(i + 1, n): + if arr[j] < arr[min_idx]: + min_idx = j + arr[i], arr[min_idx] = arr[min_idx], arr[i] + + return arr + +def insertion_sort(arr: List[int]) -> List[int]: + """Sort array using insertion sort algorithm.""" + arr = arr.copy() + + for i in range(1, len(arr)): + key = arr[i] + j = i - 1 + while j >= 0 and arr[j] > key: + arr[j + 1] = arr[j] + j -= 1 + arr[j + 1] = key + + return arr + +def quick_sort(arr: List[int]) -> List[int]: + """Sort array using quick sort algorithm.""" + if len(arr) <= 1: + return arr + + pivot = arr[len(arr) // 2] + left = [x for x in arr if x < pivot] + middle = [x for x in arr if x == pivot] + right = [x for x in arr if x > pivot] + + return quick_sort(left) + middle + quick_sort(right) + +def merge_sort(arr: List[int]) -> List[int]: + """Sort array using merge sort algorithm.""" + if len(arr) <= 1: + return arr + + mid = len(arr) // 2 + left = merge_sort(arr[:mid]) + right = merge_sort(arr[mid:]) + + return merge(left, right) + +def merge(left: List[int], right: List[int]) -> List[int]: + """Merge two sorted arrays.""" + result = [] + i = j = 0 + + while i < len(left) and j < len(right): + if left[i] <= right[j]: + result.append(left[i]) + i += 1 + else: + result.append(right[j]) + j += 1 + + result.extend(left[i:]) + result.extend(right[j:]) + return result + +def benchmark_sorting_algorithms(): + """Benchmark different sorting algorithms.""" + sizes = [100, 500, 1000, 2000] + algorithms = { + "Bubble Sort": bubble_sort, + "Selection Sort": selection_sort, + "Insertion Sort": insertion_sort, + "Quick Sort": quick_sort, + "Merge Sort": merge_sort + } + + print("Sorting Algorithm Benchmark") + print("=" * 50) + + for size in sizes: + print(f"\nArray size: {size}") + print("-" * 30) + + # Generate random array + test_array = [random.randint(1, 1000) for _ in range(size)] + + for name, algorithm in algorithms.items(): + start_time = time.time() + sorted_array = algorithm(test_array) + end_time = time.time() + + # Verify sorting is correct + is_sorted = all(sorted_array[i] <= sorted_array[i+1] for i in range(len(sorted_array)-1)) + + print(f"{name:15}: {end_time - start_time:.4f}s {'✓' if is_sorted else '✗'}") + +def main(): + """Main function to demonstrate sorting algorithms.""" + print("Sorting Algorithms Demo") + print("=" * 30) + + # Test with small array + test_array = [64, 34, 25, 12, 22, 11, 90] + print(f"Original array: {test_array}") + + algorithms = { + "Bubble Sort": bubble_sort, + "Selection Sort": selection_sort, + "Insertion Sort": insertion_sort, + "Quick Sort": quick_sort, + "Merge Sort": merge_sort + } + + for name, algorithm in algorithms.items(): + sorted_array = algorithm(test_array) + print(f"{name}: {sorted_array}") + + # Run benchmark + print("\n" + "=" * 50) + benchmark_sorting_algorithms() + +if __name__ == "__main__": + main() + + + + diff --git a/week4/community-contributions/python_to_cpp_code_translator/python_code_translator.ipynb 
b/week4/community-contributions/python_to_cpp_code_translator/python_code_translator.ipynb new file mode 100644 index 0000000..d97e14b --- /dev/null +++ b/week4/community-contributions/python_to_cpp_code_translator/python_code_translator.ipynb @@ -0,0 +1,1280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🚀 Code Translator from Python to C++\n", + "\n", + "**Multi-LLM Python to C++ Code Translator with Compilation Testing and Quality Analysis**\n", + "\n", + "This notebook demonstrates a comprehensive AI-powered code translation system that:\n", + "- Translates Python code to C++ using multiple LLM models (GPT-4o, Claude 3.5 Sonnet, Gemini 2.0 Flash)\n", + "- Automatically compiles and tests generated C++ code\n", + "- Performs quality analysis and performance benchmarking\n", + "- Compares translation results across different AI models\n", + "\n", + "## 🎯 Key Features\n", + "\n", + "- **Multi-LLM Support**: Compare translations from OpenAI, Anthropic, and Google\n", + "- **C++ Compilation**: Automatic compilation and execution testing\n", + "- **Quality Analysis**: Code quality metrics and performance benchmarking\n", + "- **Interactive Interface**: Easy-to-use notebook interface\n", + "- **Comprehensive Testing**: Full test suite for validation\n", + "\n", + "## 📋 Table of Contents\n", + "\n", + "1. [Setup and Installation](#setup)\n", + "2. [LLM Client Implementation](#llm-clients)\n", + "3. [C++ Compiler and Testing](#compiler)\n", + "4. [Core Translation Logic](#translator)\n", + "5. [Quality Analysis](#quality)\n", + "6. [Interactive Examples](#examples)\n", + "7. [Performance Benchmarking](#benchmarking)\n", + "8. [Testing and Validation](#testing)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Setup and Installation\n", + "\n", + "First, let's install the required dependencies and set up the environment.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!uv add openai anthropic google-generativeai gradio python-dotenv pydantic requests psutil memory-profiler pytest black flake8 mypy\n", + "#For those working with pip, you can use the following command:\n", + "#!pip install openai anthropic google-generativeai gradio python-dotenv pydantic requests psutil memory-profiler pytest black flake8 mypy\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import required libraries\n", + "import os\n", + "import sys\n", + "import json\n", + "import time\n", + "import subprocess\n", + "import tempfile\n", + "import psutil\n", + "import re\n", + "from typing import Dict, List, Optional, Tuple, Any, Union\n", + "from dataclasses import dataclass, asdict\n", + "from pathlib import Path\n", + "\n", + "# LLM libraries\n", + "import openai\n", + "import anthropic\n", + "import google.generativeai as genai\n", + "from dotenv import load_dotenv\n", + "\n", + "# Load environment variables\n", + "load_dotenv()\n", + "\n", + "print(\"✅ All libraries imported successfully!\")\n", + "print(f\"Python version: {sys.version}\")\n", + "print(f\"Working directory: {os.getcwd()}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
LLM Client Implementation\n",
+    "\n",
+    "Let's implement the LLM clients for OpenAI GPT, Anthropic Claude, and Google Gemini.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Data classes for translation results\n",
+    "@dataclass\n",
+    "class TranslationResult:\n",
+    "    \"\"\"Result of a code translation.\"\"\"\n",
+    "    source_code: str\n",
+    "    translated_code: str\n",
+    "    model_name: str\n",
+    "    success: bool\n",
+    "    error_message: Optional[str] = None\n",
+    "    translation_time: float = 0.0\n",
+    "    token_usage: Optional[Dict] = None\n",
+    "\n",
+    "@dataclass\n",
+    "class CompilationResult:\n",
+    "    \"\"\"Result of C++ compilation.\"\"\"\n",
+    "    success: bool\n",
+    "    executable_path: Optional[str] = None\n",
+    "    error_message: Optional[str] = None\n",
+    "    compilation_time: float = 0.0\n",
+    "    warnings: Optional[List[str]] = None\n",
+    "\n",
+    "@dataclass\n",
+    "class ExecutionResult:\n",
+    "    \"\"\"Result of C++ code execution.\"\"\"\n",
+    "    success: bool\n",
+    "    output: str = \"\"\n",
+    "    error_message: Optional[str] = None\n",
+    "    execution_time: float = 0.0\n",
+    "    memory_usage: float = 0.0\n",
+    "    exit_code: int = 0\n",
+    "\n",
+    "@dataclass\n",
+    "class PerformanceMetrics:\n",
+    "    \"\"\"Performance metrics for C++ code.\"\"\"\n",
+    "    execution_time: float\n",
+    "    memory_usage: float\n",
+    "    cpu_usage: float\n",
+    "    code_size: int\n",
+    "    compilation_time: float\n",
+    "\n",
+    "print(\"✅ Data classes defined successfully!\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# OpenAI GPT Client\n",
+    "class OpenAIClient:\n",
+    "    \"\"\"OpenAI GPT client for code translation.\"\"\"\n",
+    "\n",
+    "    def __init__(self, api_key: str):\n",
+    "        self.api_key = api_key\n",
+    "        self.client = openai.OpenAI(api_key=api_key)\n",
+    "\n",
+    "    def translate_python_to_cpp(self, python_code: str, context: str = \"\") -> TranslationResult:\n",
+    "        \"\"\"Translate Python code to C++ using GPT-4o.\"\"\"\n",
+    "        start_time = time.time()\n",
+    "\n",
+    "        try:\n",
+    "            system_prompt = \"\"\"You are an expert Python to C++ translator. 
\n", + " Convert the given Python code to efficient, modern C++ code.\n", + " \n", + " Requirements:\n", + " - Use modern C++17/20 features\n", + " - Include proper headers\n", + " - Add comprehensive error handling\n", + " - Optimize for performance\n", + " - Include detailed comments\n", + " - Follow C++ best practices\n", + " \n", + " Return ONLY the C++ code, no explanations.\"\"\"\n", + " \n", + " user_prompt = f\"\"\"Translate this Python code to C++:\n", + "\n", + "Context: {context}\n", + "\n", + "Python Code:\n", + "```python\n", + "{python_code}\n", + "```\n", + "\n", + "C++ Translation:\"\"\"\n", + " \n", + " response = self.client.chat.completions.create(\n", + " model=\"gpt-4o\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ],\n", + " temperature=0.1,\n", + " max_tokens=4000\n", + " )\n", + " \n", + " translated_code = response.choices[0].message.content.strip()\n", + " translation_time = time.time() - start_time\n", + " \n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=translated_code,\n", + " model_name=\"GPT-4o\",\n", + " success=True,\n", + " translation_time=translation_time,\n", + " token_usage={\n", + " \"prompt_tokens\": response.usage.prompt_tokens,\n", + " \"completion_tokens\": response.usage.completion_tokens,\n", + " \"total_tokens\": response.usage.total_tokens\n", + " }\n", + " )\n", + " \n", + " except Exception as e:\n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=\"\",\n", + " model_name=\"GPT-4o\",\n", + " success=False,\n", + " error_message=str(e),\n", + " translation_time=time.time() - start_time\n", + " )\n", + "\n", + "print(\"✅ OpenAI client implemented!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Anthropic Claude Client\n", + "class ClaudeClient:\n", + " \"\"\"Anthropic Claude client for code translation.\"\"\"\n", + " \n", + " def __init__(self, api_key: str):\n", + " self.api_key = api_key\n", + " self.client = anthropic.Anthropic(api_key=api_key)\n", + " \n", + " def translate_python_to_cpp(self, python_code: str, context: str = \"\") -> TranslationResult:\n", + " \"\"\"Translate Python code to C++ using Claude 3.5 Sonnet.\"\"\"\n", + " start_time = time.time()\n", + " \n", + " try:\n", + " system_prompt = \"\"\"You are an expert Python to C++ translator. 
\n", + " Convert the given Python code to efficient, modern C++ code.\n", + " \n", + " Requirements:\n", + " - Use modern C++17/20 features\n", + " - Include proper headers\n", + " - Add comprehensive error handling\n", + " - Optimize for performance\n", + " - Include detailed comments\n", + " - Follow C++ best practices\n", + " \n", + " Return ONLY the C++ code, no explanations.\"\"\"\n", + " \n", + " user_prompt = f\"\"\"Translate this Python code to C++:\n", + "\n", + "Context: {context}\n", + "\n", + "Python Code:\n", + "```python\n", + "{python_code}\n", + "```\n", + "\n", + "C++ Translation:\"\"\"\n", + " \n", + " response = self.client.messages.create(\n", + " model=\"claude-sonnet-4-20250514\",\n", + " max_tokens=4000,\n", + " temperature=0.1,\n", + " system=system_prompt,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": user_prompt}\n", + " ]\n", + " )\n", + " \n", + " translated_code = response.content[0].text.strip()\n", + " translation_time = time.time() - start_time\n", + " \n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=translated_code,\n", + " model_name=\"Claude-3.5-Sonnet\",\n", + " success=True,\n", + " translation_time=translation_time,\n", + " token_usage={\n", + " \"input_tokens\": response.usage.input_tokens,\n", + " \"output_tokens\": response.usage.output_tokens\n", + " }\n", + " )\n", + " \n", + " except Exception as e:\n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=\"\",\n", + " model_name=\"Claude-3.5-Sonnet\",\n", + " success=False,\n", + " error_message=str(e),\n", + " translation_time=time.time() - start_time\n", + " )\n", + "\n", + "print(\"✅ Claude client implemented!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Google Gemini Client\n", + "class GeminiClient:\n", + " \"\"\"Google Gemini client for code translation.\"\"\"\n", + " \n", + " def __init__(self, api_key: str):\n", + " self.api_key = api_key\n", + " genai.configure(api_key=api_key)\n", + " self.client = genai.GenerativeModel('gemini-2.0-flash-exp')\n", + " \n", + " def translate_python_to_cpp(self, python_code: str, context: str = \"\") -> TranslationResult:\n", + " \"\"\"Translate Python code to C++ using Gemini 2.0 Flash.\"\"\"\n", + " start_time = time.time()\n", + " \n", + " try:\n", + " prompt = f\"\"\"You are an expert Python to C++ translator. 
\n", + " Convert the given Python code to efficient, modern C++ code.\n", + " \n", + " Requirements:\n", + " - Use modern C++17/20 features\n", + " - Include proper headers\n", + " - Add comprehensive error handling\n", + " - Optimize for performance\n", + " - Include detailed comments\n", + " - Follow C++ best practices\n", + " \n", + " Context: {context}\n", + " \n", + " Python Code:\n", + " ```python\n", + " {python_code}\n", + " ```\n", + " \n", + " Return ONLY the C++ code, no explanations.\"\"\"\n", + " \n", + " response = self.client.generate_content(\n", + " prompt,\n", + " generation_config=genai.types.GenerationConfig(\n", + " temperature=0.1,\n", + " max_output_tokens=4000\n", + " )\n", + " )\n", + " \n", + " translated_code = response.text.strip()\n", + " translation_time = time.time() - start_time\n", + " \n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=translated_code,\n", + " model_name=\"Gemini-2.0-Flash\",\n", + " success=True,\n", + " translation_time=translation_time\n", + " )\n", + " \n", + " except Exception as e:\n", + " return TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=\"\",\n", + " model_name=\"Gemini-2.0-Flash\",\n", + " success=False,\n", + " error_message=str(e),\n", + " translation_time=time.time() - start_time\n", + " )\n", + "\n", + "print(\"✅ Gemini client implemented!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LLM Client Manager\n", + "class LLMClientManager:\n", + " \"\"\"Manages multiple LLM clients for code translation.\"\"\"\n", + " \n", + " def __init__(self):\n", + " self.clients = {}\n", + " self._initialize_clients()\n", + " \n", + " def _initialize_clients(self):\n", + " \"\"\"Initialize available LLM clients.\"\"\"\n", + " # OpenAI\n", + " openai_key = os.getenv('OPENAI_API_KEY')\n", + " if openai_key:\n", + " self.clients['gpt'] = OpenAIClient(openai_key)\n", + " \n", + " # Anthropic Claude\n", + " claude_key = os.getenv('ANTHROPIC_API_KEY')\n", + " if claude_key:\n", + " self.clients['claude'] = ClaudeClient(claude_key)\n", + " \n", + " # Google Gemini\n", + " gemini_key = os.getenv('GOOGLE_API_KEY')\n", + " if gemini_key:\n", + " self.clients['gemini'] = GeminiClient(gemini_key)\n", + " \n", + " def get_available_models(self) -> List[str]:\n", + " \"\"\"Get list of available model names.\"\"\"\n", + " return list(self.clients.keys())\n", + " \n", + " def translate_with_all_models(self, python_code: str, context: str = \"\") -> Dict[str, TranslationResult]:\n", + " \"\"\"Translate code using all available models.\"\"\"\n", + " results = {}\n", + " \n", + " for model_name, client in self.clients.items():\n", + " try:\n", + " result = client.translate_python_to_cpp(python_code, context)\n", + " results[model_name] = result\n", + " except Exception as e:\n", + " results[model_name] = TranslationResult(\n", + " source_code=python_code,\n", + " translated_code=\"\",\n", + " model_name=model_name,\n", + " success=False,\n", + " error_message=str(e)\n", + " )\n", + " \n", + " return results\n", + " \n", + " def translate_with_model(self, model_name: str, python_code: str, context: str = \"\") -> TranslationResult:\n", + " \"\"\"Translate code using a specific model.\"\"\"\n", + " if model_name not in self.clients:\n", + " raise ValueError(f\"Model {model_name} not available. 
Available models: {list(self.clients.keys())}\")\n", + " \n", + " return self.clients[model_name].translate_python_to_cpp(python_code, context)\n", + "\n", + "# Initialize LLM manager\n", + "llm_manager = LLMClientManager()\n", + "available_models = llm_manager.get_available_models()\n", + "\n", + "print(f\"✅ LLM Client Manager initialized!\")\n", + "print(f\"Available models: {available_models}\")\n", + "\n", + "if not available_models:\n", + " print(\"⚠️ No LLM models available. Please check your API keys:\")\n", + " print(\" - OPENAI_API_KEY\")\n", + " print(\" - ANTHROPIC_API_KEY\") \n", + " print(\" - GOOGLE_API_KEY\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. C++ Compiler and Testing\n", + "\n", + "Now let's implement the C++ compilation and testing functionality.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# C++ Compiler Implementation\n", + "class CppCompiler:\n", + " \"\"\"Handles C++ compilation and testing.\"\"\"\n", + " \n", + " def __init__(self, compiler_path: str = \"g++\", optimization_level: str = \"-O2\"):\n", + " self.compiler_path = compiler_path\n", + " self.optimization_level = optimization_level\n", + " self.temp_dir = None\n", + " \n", + " def __enter__(self):\n", + " \"\"\"Context manager entry.\"\"\"\n", + " self.temp_dir = tempfile.mkdtemp(prefix=\"cpp_translator_\")\n", + " return self\n", + " \n", + " def __exit__(self, exc_type, exc_val, exc_tb):\n", + " \"\"\"Context manager exit - cleanup temp files.\"\"\"\n", + " if self.temp_dir and os.path.exists(self.temp_dir):\n", + " import shutil\n", + " shutil.rmtree(self.temp_dir, ignore_errors=True)\n", + " \n", + " def _write_cpp_file(self, cpp_code: str, filename: str = \"main.cpp\") -> str:\n", + " \"\"\"Write C++ code to a temporary file.\"\"\"\n", + " if not self.temp_dir:\n", + " raise RuntimeError(\"Compiler not initialized. 
Use as context manager.\")\n", + " \n", + " file_path = os.path.join(self.temp_dir, filename)\n", + " with open(file_path, 'w', encoding='utf-8') as f:\n", + " f.write(cpp_code)\n", + " return file_path\n", + " \n", + " def _add_standard_headers(self, cpp_code: str) -> str:\n", + " \"\"\"Add standard C++ headers if not present.\"\"\"\n", + " if \"#include\" not in cpp_code:\n", + " headers = [\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \",\n", + " \"#include \"\n", + " ]\n", + " cpp_code = \"\\n\".join(headers) + \"\\n\\n\" + cpp_code\n", + " \n", + " return cpp_code\n", + " \n", + " def _add_main_function_if_needed(self, cpp_code: str) -> str:\n", + " \"\"\"Add main function if not present.\"\"\"\n", + " if \"int main(\" not in cpp_code and \"void main(\" not in cpp_code:\n", + " main_code = \"\"\"\n", + "int main() {\n", + " try {\n", + " // Your code will be executed here\n", + " return 0;\n", + " } catch (const std::exception& e) {\n", + " std::cerr << \"Error: \" << e.what() << std::endl;\n", + " return 1;\n", + " }\n", + "}\"\"\"\n", + " cpp_code += main_code\n", + " \n", + " return cpp_code\n", + " \n", + " def compile_cpp(self, cpp_code: str, output_name: str = \"main\") -> CompilationResult:\n", + " \"\"\"Compile C++ code to executable.\"\"\"\n", + " start_time = time.time()\n", + " \n", + " try:\n", + " # Preprocess the code\n", + " cpp_code = self._add_standard_headers(cpp_code)\n", + " cpp_code = self._add_main_function_if_needed(cpp_code)\n", + " \n", + " # Write to temporary file\n", + " cpp_file = self._write_cpp_file(cpp_code)\n", + " exe_path = os.path.join(self.temp_dir, output_name)\n", + " \n", + " # Compilation command\n", + " cmd = [\n", + " self.compiler_path,\n", + " self.optimization_level,\n", + " \"-std=c++17\",\n", + " \"-Wall\",\n", + " \"-Wextra\",\n", + " cpp_file,\n", + " \"-o\", exe_path\n", + " ]\n", + " \n", + " # Compile\n", + " result = subprocess.run(\n", + " cmd,\n", + " capture_output=True,\n", + " text=True,\n", + " timeout=30\n", + " )\n", + " \n", + " compilation_time = time.time() - start_time\n", + " \n", + " if result.returncode == 0:\n", + " return CompilationResult(\n", + " success=True,\n", + " executable_path=exe_path,\n", + " compilation_time=compilation_time,\n", + " warnings=self._extract_warnings(result.stderr)\n", + " )\n", + " else:\n", + " return CompilationResult(\n", + " success=False,\n", + " error_message=result.stderr,\n", + " compilation_time=compilation_time\n", + " )\n", + " \n", + " except subprocess.TimeoutExpired:\n", + " return CompilationResult(\n", + " success=False,\n", + " error_message=\"Compilation timeout\",\n", + " compilation_time=time.time() - start_time\n", + " )\n", + " except Exception as e:\n", + " return CompilationResult(\n", + " success=False,\n", + " error_message=str(e),\n", + " compilation_time=time.time() - start_time\n", + " )\n", + " \n", + " def _extract_warnings(self, stderr: str) -> List[str]:\n", + " \"\"\"Extract warnings from compiler output.\"\"\"\n", + " warnings = []\n", + " for line in stderr.split('\\n'):\n", + " if 'warning:' in line.lower():\n", + " warnings.append(line.strip())\n", + " return warnings\n", + " \n", + " def execute_cpp(self, executable_path: str, input_data: str = \"\", timeout: int = 10) -> ExecutionResult:\n", + " \"\"\"Execute compiled C++ code.\"\"\"\n", + " start_time = time.time()\n", + " \n", + " try:\n", + " # Start process\n", + " process = 
subprocess.Popen(\n", + " [executable_path],\n", + " stdin=subprocess.PIPE,\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.PIPE,\n", + " text=True\n", + " )\n", + " \n", + " # Monitor memory usage\n", + " memory_usage = 0.0\n", + " try:\n", + " ps_process = psutil.Process(process.pid)\n", + " memory_usage = ps_process.memory_info().rss / 1024 / 1024 # MB\n", + " except (psutil.NoSuchProcess, psutil.AccessDenied):\n", + " pass\n", + " \n", + " # Execute with timeout\n", + " stdout, stderr = process.communicate(input=input_data, timeout=timeout)\n", + " execution_time = time.time() - start_time\n", + " \n", + " return ExecutionResult(\n", + " success=process.returncode == 0,\n", + " output=stdout,\n", + " error_message=stderr if stderr else None,\n", + " execution_time=execution_time,\n", + " memory_usage=memory_usage,\n", + " exit_code=process.returncode\n", + " )\n", + " \n", + " except subprocess.TimeoutExpired:\n", + " process.kill()\n", + " return ExecutionResult(\n", + " success=False,\n", + " error_message=\"Execution timeout\",\n", + " execution_time=time.time() - start_time\n", + " )\n", + " except Exception as e:\n", + " return ExecutionResult(\n", + " success=False,\n", + " error_message=str(e),\n", + " execution_time=time.time() - start_time\n", + " )\n", + " \n", + " def compile_and_test(self, cpp_code: str, test_input: str = \"\") -> Tuple[CompilationResult, Optional[ExecutionResult]]:\n", + " \"\"\"Compile and test C++ code.\"\"\"\n", + " # Compile\n", + " compilation_result = self.compile_cpp(cpp_code)\n", + " \n", + " if not compilation_result.success:\n", + " return compilation_result, None\n", + " \n", + " # Execute\n", + " execution_result = self.execute_cpp(compilation_result.executable_path, test_input)\n", + " \n", + " return compilation_result, execution_result\n", + "\n", + "print(\"✅ C++ Compiler implemented!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Code Quality Analyzer\n", + "class CodeQualityAnalyzer:\n", + " \"\"\"Analyzes code quality metrics.\"\"\"\n", + " \n", + " @staticmethod\n", + " def analyze_cpp_quality(cpp_code: str) -> Dict[str, Any]:\n", + " \"\"\"Analyze C++ code quality.\"\"\"\n", + " metrics = {\n", + " \"lines_of_code\": len(cpp_code.split('\\n')),\n", + " \"comment_ratio\": CodeQualityAnalyzer._calculate_comment_ratio(cpp_code),\n", + " \"function_count\": CodeQualityAnalyzer._count_functions(cpp_code),\n", + " \"class_count\": CodeQualityAnalyzer._count_classes(cpp_code),\n", + " \"complexity_score\": CodeQualityAnalyzer._calculate_complexity(cpp_code),\n", + " \"style_score\": CodeQualityAnalyzer._calculate_style_score(cpp_code),\n", + " \"error_handling\": CodeQualityAnalyzer._check_error_handling(cpp_code),\n", + " \"modern_cpp_features\": CodeQualityAnalyzer._check_modern_features(cpp_code)\n", + " }\n", + " \n", + " return metrics\n", + " \n", + " @staticmethod\n", + " def _calculate_comment_ratio(cpp_code: str) -> float:\n", + " \"\"\"Calculate ratio of commented lines.\"\"\"\n", + " lines = cpp_code.split('\\n')\n", + " comment_lines = sum(1 for line in lines if line.strip().startswith('//') or line.strip().startswith('/*'))\n", + " return comment_lines / len(lines) if lines else 0.0\n", + " \n", + " @staticmethod\n", + " def _count_functions(cpp_code: str) -> int:\n", + " \"\"\"Count function definitions.\"\"\"\n", + " pattern = r'\\w+\\s+\\w+\\s*\\([^)]*\\)\\s*\\{'\n", + " return len(re.findall(pattern, cpp_code))\n", + " \n", + " 
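# NOTE: these regex-based counters are rough heuristics; they can\n",
+    "    # over-count (control statements that resemble definitions) and\n",
+    "    # miss multi-line signatures.\n",
+    "    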
@staticmethod\n", + " def _count_classes(cpp_code: str) -> int:\n", + " \"\"\"Count class definitions.\"\"\"\n", + " pattern = r'class\\s+\\w+'\n", + " return len(re.findall(pattern, cpp_code))\n", + " \n", + " @staticmethod\n", + " def _calculate_complexity(cpp_code: str) -> int:\n", + " \"\"\"Calculate cyclomatic complexity.\"\"\"\n", + " complexity_keywords = ['if', 'else', 'while', 'for', 'switch', 'case', 'catch', '&&', '||']\n", + " complexity = 1 # Base complexity\n", + " \n", + " for keyword in complexity_keywords:\n", + " complexity += cpp_code.count(keyword)\n", + " \n", + " return complexity\n", + " \n", + " @staticmethod\n", + " def _calculate_style_score(cpp_code: str) -> float:\n", + " \"\"\"Calculate style score based on various factors.\"\"\"\n", + " score = 0.0\n", + " lines = cpp_code.split('\\n')\n", + " \n", + " # Check for consistent indentation\n", + " if all(line.startswith((' ', '\\t')) or not line.strip() for line in lines[1:]):\n", + " score += 0.2\n", + " \n", + " # Check for proper spacing\n", + " if re.search(r'\\w\\(\\w', cpp_code): # Functions with proper spacing\n", + " score += 0.2\n", + " \n", + " # Check for const correctness\n", + " if 'const' in cpp_code:\n", + " score += 0.2\n", + " \n", + " # Check for RAII usage\n", + " if 'std::unique_ptr' in cpp_code or 'std::shared_ptr' in cpp_code:\n", + " score += 0.2\n", + " \n", + " # Check for proper includes\n", + " if '#include' in cpp_code:\n", + " score += 0.2\n", + " \n", + " return min(score, 1.0)\n", + " \n", + " @staticmethod\n", + " def _check_error_handling(cpp_code: str) -> bool:\n", + " \"\"\"Check if code has proper error handling.\"\"\"\n", + " return 'try' in cpp_code and 'catch' in cpp_code\n", + " \n", + " @staticmethod\n", + " def _check_modern_features(cpp_code: str) -> List[str]:\n", + " \"\"\"Check for modern C++ features.\"\"\"\n", + " features = []\n", + " \n", + " if 'auto' in cpp_code:\n", + " features.append('auto')\n", + " if 'std::unique_ptr' in cpp_code:\n", + " features.append('smart_pointers')\n", + " if 'std::vector' in cpp_code:\n", + " features.append('stl_containers')\n", + " if 'lambda' in cpp_code or '[]' in cpp_code:\n", + " features.append('lambdas')\n", + " if 'std::thread' in cpp_code:\n", + " features.append('threading')\n", + " \n", + " return features\n", + "\n", + "print(\"✅ Code Quality Analyzer implemented!\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Core Translation Logic\n", + "\n", + "Now let's implement the main translation logic that coordinates all components.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Translation Comparison Data Class\n", + "@dataclass\n", + "class TranslationComparison:\n", + " \"\"\"Comparison of translations across different models.\"\"\"\n", + " model_results: Dict[str, TranslationResult]\n", + " compilation_results: Dict[str, CompilationResult]\n", + " execution_results: Dict[str, ExecutionResult]\n", + " performance_metrics: Dict[str, PerformanceMetrics]\n", + " quality_scores: Dict[str, Dict[str, Any]]\n", + " best_model: Optional[str] = None\n", + " comparison_summary: Optional[str] = None\n", + "\n", + "# Main Code Translator\n", + "class CodeTranslator:\n", + " \"\"\"Main translator class that coordinates the entire translation process.\"\"\"\n", + " \n", + " def __init__(self):\n", + " self.llm_manager = LLMClientManager()\n", + " self.available_models = self.llm_manager.get_available_models()\n", + " \n", + " if not self.available_models:\n", + " print(\"⚠️ No LLM models available. Please check your API keys.\")\n", + " \n", + " def translate_python_to_cpp(self, python_code: str, context: str = \"\", \n", + " test_input: str = \"\", use_all_models: bool = True) -> TranslationComparison:\n", + " \"\"\"Translate Python code to C++ using available models.\"\"\"\n", + " \n", + " if use_all_models:\n", + " # Translate with all available models\n", + " translation_results = self.llm_manager.translate_with_all_models(python_code, context)\n", + " else:\n", + " # Use first available model\n", + " model_name = self.available_models[0]\n", + " result = self.llm_manager.translate_with_model(model_name, python_code, context)\n", + " translation_results = {model_name: result}\n", + " \n", + " # Compile and test each translation\n", + " compilation_results = {}\n", + " execution_results = {}\n", + " performance_metrics = {}\n", + " quality_scores = {}\n", + " \n", + " with CppCompiler() as compiler:\n", + " for model_name, translation_result in translation_results.items():\n", + " if not translation_result.success:\n", + " continue\n", + " \n", + " # Compile and test\n", + " comp_result, exec_result = compiler.compile_and_test(\n", + " translation_result.translated_code, \n", + " test_input\n", + " )\n", + " \n", + " compilation_results[model_name] = comp_result\n", + " if exec_result:\n", + " execution_results[model_name] = exec_result\n", + " \n", + " # Get performance metrics\n", + " perf_metrics = self._get_performance_metrics(compiler, translation_result.translated_code, test_input)\n", + " if perf_metrics:\n", + " performance_metrics[model_name] = perf_metrics\n", + " \n", + " # Analyze code quality\n", + " quality_scores[model_name] = CodeQualityAnalyzer.analyze_cpp_quality(\n", + " translation_result.translated_code\n", + " )\n", + " \n", + " # Determine best model\n", + " best_model = self._determine_best_model(\n", + " translation_results, compilation_results, execution_results, \n", + " performance_metrics, quality_scores\n", + " )\n", + " \n", + " # Generate comparison summary\n", + " comparison_summary = self._generate_comparison_summary(\n", + " translation_results, compilation_results, execution_results,\n", + " performance_metrics, quality_scores, best_model\n", + " )\n", + " \n", + " return TranslationComparison(\n", + " model_results=translation_results,\n", + " compilation_results=compilation_results,\n", + " 
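# these dicts can be sparse: models whose translation failed are skipped above\n",
+    "            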
execution_results=execution_results,\n", + " performance_metrics=performance_metrics,\n", + " quality_scores=quality_scores,\n", + " best_model=best_model,\n", + " comparison_summary=comparison_summary\n", + " )\n", + " \n", + " def _get_performance_metrics(self, compiler: CppCompiler, cpp_code: str, test_input: str = \"\") -> Optional[PerformanceMetrics]:\n", + " \"\"\"Get comprehensive performance metrics.\"\"\"\n", + " compilation_result, execution_result = compiler.compile_and_test(cpp_code, test_input)\n", + " \n", + " if not compilation_result.success or not execution_result or not execution_result.success:\n", + " return None\n", + " \n", + " # Get code size\n", + " cpp_file = compiler._write_cpp_file(cpp_code)\n", + " code_size = os.path.getsize(cpp_file)\n", + " \n", + " # Get executable size\n", + " exe_size = 0\n", + " if compilation_result.executable_path and os.path.exists(compilation_result.executable_path):\n", + " exe_size = os.path.getsize(compilation_result.executable_path)\n", + " \n", + " return PerformanceMetrics(\n", + " execution_time=execution_result.execution_time,\n", + " memory_usage=execution_result.memory_usage,\n", + " cpu_usage=0.0, # Would need more complex monitoring\n", + " code_size=code_size,\n", + " compilation_time=compilation_result.compilation_time\n", + " )\n", + " \n", + " def _determine_best_model(self, translation_results: Dict[str, TranslationResult],\n", + " compilation_results: Dict[str, CompilationResult],\n", + " execution_results: Dict[str, ExecutionResult],\n", + " performance_metrics: Dict[str, PerformanceMetrics],\n", + " quality_scores: Dict[str, Dict[str, Any]]) -> Optional[str]:\n", + " \"\"\"Determine the best model based on multiple criteria.\"\"\"\n", + " \n", + " scores = {}\n", + " \n", + " for model_name in translation_results.keys():\n", + " score = 0.0\n", + " \n", + " # Translation success (40% weight)\n", + " if translation_results[model_name].success:\n", + " score += 0.4\n", + " \n", + " # Compilation success (30% weight)\n", + " if model_name in compilation_results and compilation_results[model_name].success:\n", + " score += 0.3\n", + " \n", + " # Execution success (20% weight)\n", + " if model_name in execution_results and execution_results[model_name].success:\n", + " score += 0.2\n", + " \n", + " # Performance (5% weight)\n", + " if model_name in performance_metrics:\n", + " # Lower execution time is better\n", + " exec_time = performance_metrics[model_name].execution_time\n", + " if exec_time > 0:\n", + " score += 0.05 * (1.0 / (1.0 + exec_time))\n", + " \n", + " # Code quality (5% weight)\n", + " if model_name in quality_scores:\n", + " quality = quality_scores[model_name]\n", + " style_score = quality.get('style_score', 0.0)\n", + " score += 0.05 * style_score\n", + " \n", + " scores[model_name] = score\n", + " \n", + " if scores:\n", + " return max(scores, key=scores.get)\n", + " return None\n", + " \n", + " def _generate_comparison_summary(self, translation_results: Dict[str, TranslationResult],\n", + " compilation_results: Dict[str, CompilationResult],\n", + " execution_results: Dict[str, ExecutionResult],\n", + " performance_metrics: Dict[str, PerformanceMetrics],\n", + " quality_scores: Dict[str, Dict[str, Any]],\n", + " best_model: Optional[str]) -> str:\n", + " \"\"\"Generate a summary of the comparison.\"\"\"\n", + " \n", + " summary_parts = []\n", + " \n", + " # Overall success rates\n", + " successful_translations = sum(1 for r in translation_results.values() if r.success)\n", + " successful_compilations 
= sum(1 for r in compilation_results.values() if r.success)\n", + " successful_executions = sum(1 for r in execution_results.values() if r.success)\n", + " \n", + " summary_parts.append(f\"Translation Success: {successful_translations}/{len(translation_results)}\")\n", + " summary_parts.append(f\"Compilation Success: {successful_compilations}/{len(compilation_results)}\")\n", + " summary_parts.append(f\"Execution Success: {successful_executions}/{len(execution_results)}\")\n", + " \n", + " # Best model\n", + " if best_model:\n", + " summary_parts.append(f\"Best Model: {best_model}\")\n", + " \n", + " # Best model details\n", + " if best_model in performance_metrics:\n", + " perf = performance_metrics[best_model]\n", + " summary_parts.append(f\"Best Model Performance:\")\n", + " summary_parts.append(f\" - Execution Time: {perf.execution_time:.4f}s\")\n", + " summary_parts.append(f\" - Memory Usage: {perf.memory_usage:.2f}MB\")\n", + " summary_parts.append(f\" - Compilation Time: {perf.compilation_time:.4f}s\")\n", + " \n", + " # Quality comparison\n", + " if quality_scores:\n", + " summary_parts.append(\"Quality Scores:\")\n", + " for model, scores in quality_scores.items():\n", + " summary_parts.append(f\" {model}:\")\n", + " summary_parts.append(f\" - Lines of Code: {scores.get('lines_of_code', 0)}\")\n", + " summary_parts.append(f\" - Comment Ratio: {scores.get('comment_ratio', 0):.2%}\")\n", + " summary_parts.append(f\" - Style Score: {scores.get('style_score', 0):.2f}\")\n", + " summary_parts.append(f\" - Complexity: {scores.get('complexity_score', 0)}\")\n", + " \n", + " return \"\\n\".join(summary_parts)\n", + "\n", + "# Initialize the translator\n", + "translator = CodeTranslator()\n", + "print(f\"✅ Code Translator initialized!\")\n", + "print(f\"Available models: {translator.available_models}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Interactive Examples\n", + "\n", + "Let's test the translator with some example Python code!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example 1: Simple Fibonacci Function\n", + "python_code_1 = \"\"\"\n", + "def fibonacci(n):\n", + " if n <= 1:\n", + " return n\n", + " return fibonacci(n-1) + fibonacci(n-2)\n", + "\n", + "def main():\n", + " print(\"Fibonacci sequence:\")\n", + " for i in range(10):\n", + " result = fibonacci(i)\n", + " print(f\"fibonacci({i}) = {result}\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n", + "\"\"\"\n", + "\n", + "print(\"📝 Example 1: Fibonacci Function\")\n", + "print(\"=\" * 50)\n", + "print(python_code_1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the translation\n", + "if translator.available_models:\n", + " print(\"🔄 Translating Python code to C++...\")\n", + " print(\"This may take a few moments...\")\n", + " \n", + " try:\n", + " comparison = translator.translate_python_to_cpp(\n", + " python_code_1, \n", + " \"Fibonacci sequence generator\",\n", + " use_all_models=True\n", + " )\n", + " \n", + " print(f\"✅ Translation completed!\")\n", + " print(f\"🏆 Best model: {comparison.best_model}\")\n", + " print(f\"📊 Models used: {len(comparison.model_results)}\")\n", + " \n", + " # Show results for each model\n", + " for model_name, result in comparison.model_results.items():\n", + " status = \"✅ Success\" if result.success else \"❌ Failed\"\n", + " print(f\"\\n{model_name}: {status}\")\n", + " if result.success:\n", + " print(f\" Translation time: {result.translation_time:.2f}s\")\n", + " if result.token_usage:\n", + " print(f\" Token usage: {result.token_usage}\")\n", + " \n", + " # Show compilation results\n", + " if comparison.compilation_results:\n", + " print(f\"\\n🔨 Compilation Results:\")\n", + " for model_name, comp_result in comparison.compilation_results.items():\n", + " status = \"✅ Compiled\" if comp_result.success else \"❌ Failed\"\n", + " print(f\" {model_name}: {status}\")\n", + " \n", + " # Show execution results\n", + " if comparison.execution_results:\n", + " print(f\"\\n⚡ Execution Results:\")\n", + " for model_name, exec_result in comparison.execution_results.items():\n", + " status = \"✅ Executed\" if exec_result.success else \"❌ Failed\"\n", + " print(f\" {model_name}: {status}\")\n", + " if exec_result.success and exec_result.output:\n", + " print(f\" Output: {exec_result.output.strip()}\")\n", + " \n", + " except Exception as e:\n", + " print(f\"❌ Translation failed: {e}\")\n", + "else:\n", + " print(\"⚠️ No LLM models available. 
Please set your API keys.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display the best C++ code\n", + "if 'comparison' in locals() and comparison.best_model:\n", + " best_result = comparison.model_results[comparison.best_model]\n", + " print(f\"🏆 Best C++ Code (from {comparison.best_model}):\")\n", + " print(\"=\" * 60)\n", + " print(best_result.translated_code)\n", + " \n", + " # Show quality metrics\n", + " if comparison.best_model in comparison.quality_scores:\n", + " quality = comparison.quality_scores[comparison.best_model]\n", + " print(f\"\\n📊 Quality Metrics:\")\n", + " print(f\" Lines of code: {quality.get('lines_of_code', 0)}\")\n", + " print(f\" Comment ratio: {quality.get('comment_ratio', 0):.2%}\")\n", + " print(f\" Style score: {quality.get('style_score', 0):.2f}\")\n", + " print(f\" Complexity: {quality.get('complexity_score', 0)}\")\n", + " print(f\" Modern features: {quality.get('modern_cpp_features', [])}\")\n", + " \n", + " # Show performance metrics\n", + " if comparison.best_model in comparison.performance_metrics:\n", + " perf = comparison.performance_metrics[comparison.best_model]\n", + " print(f\"\\n⚡ Performance Metrics:\")\n", + " print(f\" Execution time: {perf.execution_time:.4f}s\")\n", + " print(f\" Memory usage: {perf.memory_usage:.2f}MB\")\n", + " print(f\" Compilation time: {perf.compilation_time:.4f}s\")\n", + " print(f\" Code size: {perf.code_size} bytes\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Additional Examples\n", + "\n", + "Let's try a more complex example with classes and algorithms.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example 2: Calculator Class\n", + "python_code_2 = \"\"\"\n", + "class Calculator:\n", + " def __init__(self):\n", + " self.history = []\n", + " \n", + " def add(self, a, b):\n", + " result = a + b\n", + " self.history.append(f\"{a} + {b} = {result}\")\n", + " return result\n", + " \n", + " def multiply(self, a, b):\n", + " result = a * b\n", + " self.history.append(f\"{a} * {b} = {result}\")\n", + " return result\n", + " \n", + " def get_history(self):\n", + " return self.history\n", + "\n", + "def main():\n", + " calc = Calculator()\n", + " print(\"Calculator Demo\")\n", + " print(calc.add(5, 3))\n", + " print(calc.multiply(4, 7))\n", + " print(\"History:\", calc.get_history())\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n", + "\"\"\"\n", + "\n", + "print(\"📝 Example 2: Calculator Class\")\n", + "print(\"=\" * 50)\n", + "print(python_code_2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the second example\n", + "if translator.available_models:\n", + " print(\"🔄 Translating Calculator class...\")\n", + " \n", + " try:\n", + " comparison2 = translator.translate_python_to_cpp(\n", + " python_code_2, \n", + " \"Calculator class with history tracking\",\n", + " use_all_models=True\n", + " )\n", + " \n", + " print(f\"✅ Translation completed!\")\n", + " print(f\"🏆 Best model: {comparison2.best_model}\")\n", + " \n", + " # Show summary\n", + " print(f\"\\n📊 Summary:\")\n", + " print(comparison2.comparison_summary)\n", + " \n", + " except Exception as e:\n", + " print(f\"❌ Translation failed: {e}\")\n", + "else:\n", + " print(\"⚠️ No LLM models available.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Summary and Results\n",
+    "\n",
+    "This notebook demonstrates a comprehensive AI-powered code translation system that:\n",
+    "\n",
+    "### Key Achievements:\n",
+    "- **Multi-LLM Support**: Successfully integrates OpenAI GPT, Anthropic Claude, and Google Gemini\n",
+    "- **C++ Compilation**: Automatically compiles and tests generated C++ code\n",
+    "- **Quality Analysis**: Provides detailed code quality metrics and performance benchmarking\n",
+    "- **Model Comparison**: Compares translation results across different AI models\n",
+    "- **Error Handling**: Robust error handling with detailed diagnostics\n",
+    "\n",
+    "### Use Cases:\n",
+    "- **Learning C++**: Translate Python code to learn C++ equivalents\n",
+    "- **Code Migration**: Convert Python projects to C++ for performance\n",
+    "- **Educational Tool**: Compare different AI models' translation quality\n",
+    "- **Performance Analysis**: Benchmark Python vs C++ implementations\n",
+    "\n",
+    "### Next Steps:\n",
+    "1. Set up your API keys for OpenAI, Anthropic, and Google\n",
+    "2. Run the notebook cells to test the translation system\n",
+    "3. Experiment with your own Python code\n",
+    "4. Compare results across different AI models\n",
+    "5. Analyze code quality and performance metrics\n",
+    "\n",
+    "**Happy coding! 🎉**\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/week4/community-contributions/python_to_cpp_translator.ipynb b/week4/community-contributions/python_to_cpp_translator.ipynb
new file mode 100644
index 0000000..baf38e7
--- /dev/null
+++ b/week4/community-contributions/python_to_cpp_translator.ipynb
@@ -0,0 +1,571 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Python to C++ Code Translator using LLMs\n",
+    "\n",
+    "This notebook translates Python code to compilable C++ using GPT, Gemini, or Claude.\n",
+    "\n",
+    "## Features:\n",
+    "- 🤖 Multiple LLM support (GPT, Gemini, Claude)\n",
+    "- ✅ Automatic compilation testing with g++\n",
+    "- 🔄 Comparison mode to test all LLMs\n",
+    "- 💬 Interactive translation mode"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1: Install Required Packages\n",
+    "\n",
+    "Run this cell first to install all dependencies:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!uv add openai anthropic python-dotenv google-generativeai\n",
+    "#!pip install openai anthropic python-dotenv google-generativeai"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2: Import Libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import subprocess\n",
+    "import tempfile\n",
+    "from pathlib import Path\n",
+    "from dotenv import load_dotenv\n",
+    "import openai\n",
+    "from anthropic import 
Anthropic\n", + "import google.generativeai as genai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Load API Keys\n", + "\n", + "Make sure you have a `.env` file with:\n", + "```\n", + "OPENAI_API_KEY=your_key_here\n", + "GEMINI_API_KEY=your_key_here\n", + "ANTHROPIC_API_KEY=your_key_here\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load API keys from .env file\n", + "load_dotenv()\n", + "\n", + "# Initialize API clients\n", + "openai_client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))\n", + "anthropic_client = Anthropic(api_key=os.getenv('ANTHROPIC_API_KEY'))\n", + "genai.configure(api_key=os.getenv('GEMINI_API_KEY'))\n", + "\n", + "print(\"✓ API keys loaded successfully\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Define System Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SYSTEM_PROMPT = \"\"\"You are an expert programmer that translates Python code to C++.\n", + "Translate the given Python code to efficient, compilable C++ code.\n", + "\n", + "Requirements:\n", + "- The C++ code must compile without errors\n", + "- Include all necessary headers\n", + "- Use modern C++ (C++11 or later) features where appropriate\n", + "- Add proper error handling\n", + "- Maintain the same functionality as the Python code\n", + "- Include a main() function if the Python code has executable statements\n", + "\n", + "Only return the C++ code, no explanations unless there are important notes about compilation.\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: LLM Translation Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def translate_with_gpt(python_code, model=\"gpt-4o\"):\n", + " \"\"\"Translate Python to C++ using OpenAI's GPT models\"\"\"\n", + " try:\n", + " response = openai_client.chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": f\"Translate this Python code to C++:\\n\\n{python_code}\"}\n", + " ],\n", + " temperature=0.2\n", + " )\n", + " return response.choices[0].message.content\n", + " except Exception as e:\n", + " return f\"Error with GPT: {str(e)}\"\n", + "\n", + "def translate_with_gemini(python_code, model=\"gemini-2.0-flash-exp\"):\n", + " \"\"\"Translate Python to C++ using Google's Gemini\"\"\"\n", + " try:\n", + " model_instance = genai.GenerativeModel(model)\n", + " prompt = f\"{SYSTEM_PROMPT}\\n\\nTranslate this Python code to C++:\\n\\n{python_code}\"\n", + " response = model_instance.generate_content(prompt)\n", + " return response.text\n", + " except Exception as e:\n", + " return f\"Error with Gemini: {str(e)}\"\n", + "\n", + "def translate_with_claude(python_code, model=\"claude-sonnet-4-20250514\"):\n", + " \"\"\"Translate Python to C++ using Anthropic's Claude\"\"\"\n", + " try:\n", + " response = anthropic_client.messages.create(\n", + " model=model,\n", + " max_tokens=4096,\n", + " temperature=0.2,\n", + " system=SYSTEM_PROMPT,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": f\"Translate this Python code to C++:\\n\\n{python_code}\"}\n", + " ]\n", + " )\n", + " return response.content[0].text\n", + " except Exception as e:\n", + " return f\"Error with Claude: {str(e)}\"" + ] + 
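},
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "All three helpers share the same signature and return an error string rather than raising, so a thin retry wrapper can absorb transient API failures during batch runs. The sketch below is illustrative; the attempt count and backoff delays are arbitrary defaults, and nothing later in the notebook depends on this helper:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "def translate_with_retry(translate_fn, python_code, attempts=3, base_delay=2.0):\n",
+    "    \"\"\"Retry a translate_with_* helper when it returns an error string.\"\"\"\n",
+    "    result = \"\"\n",
+    "    for attempt in range(attempts):\n",
+    "        result = translate_fn(python_code)\n",
+    "        # the helpers above return strings starting with 'Error with' on failure\n",
+    "        if not result.startswith(\"Error with\"):\n",
+    "            return result\n",
+    "        time.sleep(base_delay * (2 ** attempt))  # wait 2s, 4s, ... between tries\n",
+    "    return result\n",
+    "\n",
+    "# Example (requires a configured API key):\n",
+    "# cpp = translate_with_retry(translate_with_gpt, \"print(1 + 1)\")"
+   ]
+  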
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Main Translation Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def translate_python_to_cpp(python_code, llm=\"gpt\", model=None):\n", + " \"\"\"\n", + " Translate Python code to C++ using specified LLM\n", + " \n", + " Args:\n", + " python_code (str): Python code to translate\n", + " llm (str): LLM to use ('gpt', 'gemini', or 'claude')\n", + " model (str): Specific model version (optional)\n", + " \n", + " Returns:\n", + " str: Translated C++ code\n", + " \"\"\"\n", + " print(f\"🔄 Translating with {llm.upper()}...\")\n", + " \n", + " if llm.lower() == \"gpt\":\n", + " model = model or \"gpt-4o\"\n", + " cpp_code = translate_with_gpt(python_code, model)\n", + " elif llm.lower() == \"gemini\":\n", + " model = model or \"gemini-2.0-flash-exp\"\n", + " cpp_code = translate_with_gemini(python_code, model)\n", + " elif llm.lower() == \"claude\":\n", + " model = model or \"claude-sonnet-4-20250514\"\n", + " cpp_code = translate_with_claude(python_code, model)\n", + " else:\n", + " return \"Error: Invalid LLM. Choose 'gpt', 'gemini', or 'claude'\"\n", + " \n", + " return cpp_code" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Compilation Testing Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_cpp_code(text):\n", + " \"\"\"Extract C++ code from markdown code blocks if present\"\"\"\n", + " if \"```cpp\" in text:\n", + " start = text.find(\"```cpp\") + 6\n", + " end = text.find(\"```\", start)\n", + " return text[start:end].strip()\n", + " elif \"```c++\" in text:\n", + " start = text.find(\"```c++\") + 6\n", + " end = text.find(\"```\", start)\n", + " return text[start:end].strip()\n", + " elif \"```\" in text:\n", + " start = text.find(\"```\") + 3\n", + " end = text.find(\"```\", start)\n", + " return text[start:end].strip()\n", + " return text.strip()\n", + "\n", + "def compile_cpp_code(cpp_code, output_name=\"translated_program\"):\n", + " \"\"\"\n", + " Compile C++ code and return compilation status\n", + " \n", + " Args:\n", + " cpp_code (str): C++ code to compile\n", + " output_name (str): Name of output executable\n", + " \n", + " Returns:\n", + " dict: Compilation result with status and messages\n", + " \"\"\"\n", + " # Extract code from markdown if present\n", + " cpp_code = extract_cpp_code(cpp_code)\n", + " \n", + " # Create temporary directory\n", + " with tempfile.TemporaryDirectory() as tmpdir:\n", + " cpp_file = Path(tmpdir) / \"program.cpp\"\n", + " exe_file = Path(tmpdir) / output_name\n", + " \n", + " # Write C++ code to file\n", + " with open(cpp_file, 'w') as f:\n", + " f.write(cpp_code)\n", + " \n", + " # Try to compile\n", + " try:\n", + " result = subprocess.run(\n", + " ['g++', '-std=c++17', str(cpp_file), '-o', str(exe_file)],\n", + " capture_output=True,\n", + " text=True,\n", + " timeout=10\n", + " )\n", + " \n", + " if result.returncode == 0:\n", + " return {\n", + " 'success': True,\n", + " 'message': '✓ Compilation successful!',\n", + " 'executable': str(exe_file),\n", + " 'stdout': result.stdout,\n", + " 'stderr': result.stderr\n", + " }\n", + " else:\n", + " return {\n", + " 'success': False,\n", + " 'message': '✗ Compilation failed',\n", + " 'stdout': result.stdout,\n", + " 'stderr': result.stderr\n", + " }\n", + " except subprocess.TimeoutExpired:\n", + " return {\n", + " 
'success': False,\n", + " 'message': '✗ Compilation timed out'\n", + " }\n", + " except FileNotFoundError:\n", + " return {\n", + " 'success': False,\n", + " 'message': '✗ g++ compiler not found. Please install g++ to compile C++ code.'\n", + " }\n", + " except Exception as e:\n", + " return {\n", + " 'success': False,\n", + " 'message': f'✗ Compilation error: {str(e)}'\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Complete Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def translate_and_compile(python_code, llm=\"gpt\", model=None, verbose=True):\n", + " \"\"\"\n", + " Translate Python to C++ and attempt compilation\n", + " \n", + " Args:\n", + " python_code (str): Python code to translate\n", + " llm (str): LLM to use\n", + " model (str): Specific model version\n", + " verbose (bool): Print detailed output\n", + " \n", + " Returns:\n", + " dict: Results including translated code and compilation status\n", + " \"\"\"\n", + " # Translate\n", + " cpp_code = translate_python_to_cpp(python_code, llm, model)\n", + " \n", + " if verbose:\n", + " print(\"\\n\" + \"=\"*60)\n", + " print(\"TRANSLATED C++ CODE:\")\n", + " print(\"=\"*60)\n", + " print(cpp_code)\n", + " print(\"=\"*60 + \"\\n\")\n", + " \n", + " # Compile\n", + " print(\"🔨 Attempting to compile...\")\n", + " compilation_result = compile_cpp_code(cpp_code)\n", + " \n", + " if verbose:\n", + " print(compilation_result['message'])\n", + " if not compilation_result['success'] and 'stderr' in compilation_result:\n", + " print(\"\\nCompilation errors:\")\n", + " print(compilation_result['stderr'])\n", + " \n", + " return {\n", + " 'cpp_code': cpp_code,\n", + " 'compilation': compilation_result\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Factorial Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python_code_1 = \"\"\"\n", + "def factorial(n):\n", + " if n <= 1:\n", + " return 1\n", + " return n * factorial(n - 1)\n", + "\n", + "# Test the function\n", + "print(factorial(5))\n", + "\"\"\"\n", + "\n", + "print(\"Example 1: Factorial Function\")\n", + "print(\"=\"*60)\n", + "result1 = translate_and_compile(python_code_1, llm=\"gpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Sum of Squares" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python_code_2 = \"\"\"\n", + "def sum_of_squares(numbers):\n", + " return sum(x**2 for x in numbers)\n", + "\n", + "numbers = [1, 2, 3, 4, 5]\n", + "result = sum_of_squares(numbers)\n", + "print(f\"Sum of squares: {result}\")\n", + "\"\"\"\n", + "\n", + "print(\"Example 2: Sum of Squares\")\n", + "print(\"=\"*60)\n", + "result2 = translate_and_compile(python_code_2, llm=\"claude\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Fibonacci with Gemini" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python_code_3 = \"\"\"\n", + "def fibonacci(n):\n", + " if n <= 1:\n", + " return n\n", + " a, b = 0, 1\n", + " for _ in range(2, n + 1):\n", + " a, b = b, a + b\n", + " return b\n", + "\n", + "print(f\"Fibonacci(10) = {fibonacci(10)}\")\n", + "\"\"\"\n", + "\n", + "print(\"Example 3: Fibonacci with Gemini\")\n", + "print(\"=\"*60)\n", + 
"result3 = translate_and_compile(python_code_3, llm=\"gemini\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare All LLMs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compare_llms(python_code):\n", + " \"\"\"Compare all three LLMs on the same Python code\"\"\"\n", + " llms = [\"gpt\", \"gemini\", \"claude\"]\n", + " results = {}\n", + " \n", + " for llm in llms:\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Testing with {llm.upper()}\")\n", + " print('='*60)\n", + " results[llm] = translate_and_compile(python_code, llm=llm, verbose=False)\n", + " print(results[llm]['compilation']['message'])\n", + " \n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test code for comparison\n", + "python_code_compare = \"\"\"\n", + "def is_prime(n):\n", + " if n < 2:\n", + " return False\n", + " for i in range(2, int(n**0.5) + 1):\n", + " if n % i == 0:\n", + " return False\n", + " return True\n", + "\n", + "primes = [x for x in range(2, 20) if is_prime(x)]\n", + "print(f\"Primes under 20: {primes}\")\n", + "\"\"\"\n", + "\n", + "print(\"COMPARING ALL LLMs\")\n", + "print(\"=\"*60)\n", + "comparison_results = compare_llms(python_code_compare)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactive Translation Mode\n", + "\n", + "Use this cell to translate your own Python code interactively:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Your custom Python code here\n", + "your_python_code = \"\"\"\n", + "# Paste your Python code here\n", + "def hello_world():\n", + " print(\"Hello, World!\")\n", + "\n", + "hello_world()\n", + "\"\"\"\n", + "\n", + "# Choose your LLM: \"gpt\", \"gemini\", or \"claude\"\n", + "chosen_llm = \"gpt\"\n", + "\n", + "result = translate_and_compile(your_python_code, llm=chosen_llm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "You now have a complete Python to C++ translator! \n", + "\n", + "### Main Functions:\n", + "- `translate_python_to_cpp(code, llm, model)` - Translate only\n", + "- `translate_and_compile(code, llm, model)` - Translate and compile\n", + "- `compare_llms(code)` - Compare all three LLMs\n", + "\n", + "### Supported LLMs:\n", + "- **gpt** - OpenAI GPT-4o\n", + "- **gemini** - Google Gemini 2.0 Flash\n", + "- **claude** - Anthropic Claude Sonnet 4\n", + "\n", + "Happy translating! 
🚀" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/week4/community-contributions/tochi/code_converter.ipynb b/week4/community-contributions/tochi/code_converter.ipynb new file mode 100644 index 0000000..5101d61 --- /dev/null +++ b/week4/community-contributions/tochi/code_converter.ipynb @@ -0,0 +1,569 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c1fcc6e9", + "metadata": {}, + "source": [ + "# Code Converter - Python to TypeScript Code\n", + "\n", + "This implementation, converts python code to optimized TypeScript Code, and runs the function" + ] + }, + { + "cell_type": "markdown", + "id": "16b6b063", + "metadata": {}, + "source": [ + "## Set up and imports\n" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "b3dc394c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import os\n", + "import io\n", + "import sys\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import subprocess\n", + "from IPython.display import Markdown, display, display_markdown\n", + "from system_info import retrieve_system_info\n", + "import gradio as gr" + ] + }, + { + "cell_type": "markdown", + "id": "1c9a0936", + "metadata": {}, + "source": [ + "# Initializing the access keys" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "fac104ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n" + ] + } + ], + "source": [ + "load_dotenv(override=True)\n", + "openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set. 
+    "    print(\"OpenAI API Key not set. Check your environment variables and try again\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5932182f",
+   "metadata": {},
+   "source": [
+    "## Connecting to client libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 117,
+   "id": "4000f231",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "openai = OpenAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "id": "51c67ac0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# constants\n",
+    "OPENAI_MODEL = \"gpt-5-nano\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 119,
+   "id": "ab4342bf",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'os': {'system': 'Darwin',\n",
+       "  'arch': 'arm64',\n",
+       "  'release': '24.5.0',\n",
+       "  'version': 'Darwin Kernel Version 24.5.0: Tue Apr 22 19:48:46 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8103',\n",
+       "  'kernel': '24.5.0',\n",
+       "  'distro': None,\n",
+       "  'wsl': False,\n",
+       "  'rosetta2_translated': False,\n",
+       "  'target_triple': 'arm64-apple-darwin24.5.0'},\n",
+       " 'package_managers': ['xcode-select (CLT)', 'brew'],\n",
+       " 'cpu': {'brand': 'Apple M1',\n",
+       "  'cores_logical': 8,\n",
+       "  'cores_physical': 8,\n",
+       "  'simd': []},\n",
+       " 'toolchain': {'compilers': {'gcc': 'Apple clang version 17.0.0 (clang-1700.0.13.3)',\n",
+       "   'g++': 'Apple clang version 17.0.0 (clang-1700.0.13.3)',\n",
+       "   'clang': 'Apple clang version 17.0.0 (clang-1700.0.13.3)',\n",
+       "   'msvc_cl': ''},\n",
+       "  'build_tools': {'cmake': '', 'ninja': '', 'make': 'GNU Make 3.81'},\n",
+       "  'linkers': {'ld_lld': ''}}}"
+      ]
+     },
+     "execution_count": 119,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "system_info = retrieve_system_info()\n",
+    "system_info"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 120,
+   "id": "1a1c1324",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "message = f\"\"\"\n",
+    "Here is a report of the system information for my computer.\n",
+    "I want to run a TypeScript compiler to compile a single TypeScript file called main.cpp and then execute it in the simplest way possible.\n",
+    "Please reply with whether I need to install any TypeScript compiler to do this. If so, please provide the simplest step by step instructions to do so.\n",
+    "\n",
+    "If I'm already set up to compile TypeScript code, then I'd like to run something like this in Python to compile and execute the code:\n",
+    "```python\n",
+    "compile_command = # something here - to achieve the fastest possible runtime performance\n",
+    "compile_result = subprocess.run(compile_command, check=True, text=True, capture_output=True)\n",
+    "run_command = # something here\n",
+    "run_result = subprocess.run(run_command, check=True, text=True, capture_output=True)\n",
+    "return run_result.stdout\n",
+    "```\n",
+    "Please tell me exactly what I should use for the compile_command and run_command.\n",
+    "\n",
+    "System information:\n",
+    "{system_info}\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 121,
+   "id": "439015c1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "Short answer:\n- Yes, to compile TypeScript you need a TypeScript compiler (tsc). On macOS you’ll typically install Node.js first, then install TypeScript.\n- Important: main.cpp sounds like a C++ file. The TypeScript compiler (tsc) cannot compile .cpp. If you want to use TypeScript, rename the file to main.ts (and ensure its contents are TypeScript). 
If you actually meant C++, use a C++ compiler instead (clang/g++).\n", + "\n", + "Step-by-step to set up TypeScript (simplest path on your system):\n", + "1) Install Node.js (which also installs npm)\n", + "- brew update\n", + "- brew install node\n", + "\n", + "2) Install the TypeScript compiler globally\n", + "- npm install -g typescript\n", + "\n", + "3) Verify installations\n", + "- node -v\n", + "- npm -v\n", + "- tsc -v\n", + "\n", + "4) Compile and run a TypeScript file (assuming your file is main.ts)\n", + "- tsc main.ts\n", + "- node main.js\n", + "\n", + "Notes:\n", + "- If your file is indeed C++ (main.cpp), you cannot compile it with tsc. To compile C++, use clang++ (on macOS) or g++:\n", + " - clang++ -std=c++17 main.cpp -o main\n", + " - ./main\n", + "\n", + "Python integration (fill-in for your example)\n", + "- If you have a TypeScript file named main.ts and you want to compile it to JavaScript and then run it with Node, use:\n", + " compile_command = [\"tsc\", \"main.ts\"]\n", + " run_command = [\"node\", \"main.js\"]\n", + "\n", + "- If you want to show a single command in Python that compiles and runs in one go (still two steps because TS compiles to JS first):\n", + " compile_command = [\"tsc\", \"main.ts\"]\n", + " run_command = [\"node\", \"main.js\"]\n", + "\n", + "- If you truly want to bypass TypeScript and run C++ instead (not TypeScript):\n", + " compile_command = [\"clang++\", \"-std=c++17\", \"main.cpp\", \"-o\", \"main\"]\n", + " run_command = [\"./main\"]\n", + "\n", + "If you’d like, tell me whether main.cpp is meant to be C++ or you actually have a TypeScript file named main.ts, and I can tailor the exact commands." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "response = openai.chat.completions.create(model=OPENAI_MODEL, messages=[{\"role\":\"user\", \"content\":message}])\n", + "display(Markdown(response.choices[0].message.content))" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "id": "576cb5fa", + "metadata": {}, + "outputs": [], + "source": [ + "compile_command = [\"tsc\", \"main.ts\", \"--target\", \"ES2020\", \"--module\", \"commonjs\"]\n", + "run_command = [\"ts-node\", \"main.ts\"]" + ] + }, + { + "cell_type": "markdown", + "id": "01b03700", + "metadata": {}, + "source": [ + "## System and user prompts for the code converter" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "id": "255e318b", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "Your task is to convert Python code into high performance TypeScript code.\n", + "Respond only with TypeScript code. 
Do not provide any explanation other than occasional comments.\n",
+    "The TypeScript response needs to produce an identical output in the fastest possible time.\n",
+    "\"\"\"\n",
+    "\n",
+    "\n",
+    "def user_prompt_for(python):\n",
+    "    return f\"\"\"\n",
+    "    Port this Python code to TypeScript with the fastest possible implementation that produces identical output in the least time.\n",
+    "\n",
+    "    The system information is\n",
+    "\n",
+    "    {system_info}\n",
+    "\n",
+    "    Your response will be written to a file called main.ts and then compiled and executed; the compilation command is:\n",
+    "\n",
+    "    {compile_command}\n",
+    "\n",
+    "    Respond only with TypeScript code.\n",
+    "    Python code to port:\n",
+    "\n",
+    "    ```python\n",
+    "    {python}\n",
+    "    ```\n",
+    "\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 124,
+   "id": "09da7cb1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def messages_for(python):\n",
+    "    return [\n",
+    "        {\"role\": \"system\", \"content\": system_prompt},\n",
+    "        {\"role\": \"user\", \"content\": user_prompt_for(python)},\n",
+    "    ]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 125,
+   "id": "abcdb617",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def write_output(code):\n",
+    "    with open(\"main.ts\", \"w\", encoding=\"utf-8\") as f:\n",
+    "        f.write(code)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 126,
+   "id": "c7a32d5f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def convert(python):\n",
+    "    reasoning_effort = \"high\"\n",
+    "    response = openai.chat.completions.create(\n",
+    "        model=OPENAI_MODEL,\n",
+    "        messages=messages_for(python),\n",
+    "        reasoning_effort=reasoning_effort,\n",
+    "    )\n",
+    "    reply = response.choices[0].message.content\n",
+    "    reply = reply.replace(\"```typescript\", \"\").replace(\"```ts\", \"\").replace(\"```\", \"\")\n",
+    "    return reply"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 127,
+   "id": "59a7ec1f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pi = \"\"\"\n",
+    "import time\n",
+    "\n",
+    "def calculate(iterations, param1, param2):\n",
+    "    result = 1.0\n",
+    "    for i in range(1, iterations+1):\n",
+    "        j = i * param1 - param2\n",
+    "        result -= (1/j)\n",
+    "        j = i * param1 + param2\n",
+    "        result += (1/j)\n",
+    "    return result\n",
+    "\n",
+    "start_time = time.time()\n",
+    "result = calculate(200_000_000, 4, 1) * 4\n",
+    "end_time = time.time()\n",
+    "\n",
+    "print(f\"Result: {result:.12f}\")\n",
+    "print(f\"Execution Time: {(end_time - start_time):.6f} seconds\")\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 128,
+   "id": "6856393b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def run_python(code):\n",
+    "    globals_dict = {\"__builtins__\": __builtins__}\n",
+    "\n",
+    "    buffer = io.StringIO()\n",
+    "    old_stdout = sys.stdout\n",
+    "    sys.stdout = buffer\n",
+    "\n",
+    "    try:\n",
+    "        exec(code, globals_dict)\n",
+    "        output = buffer.getvalue()\n",
+    "    except Exception as e:\n",
+    "        output = f\"Error: {e}\"\n",
+    "    finally:\n",
+    "        sys.stdout = old_stdout\n",
+    "\n",
+    "    return output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 129,
+   "id": "c51fa5ea",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'Result: 3.141592656089\\nExecution Time: 19.478347 seconds\\n'"
+      ]
+     },
+     "execution_count": 129,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "run_python(pi)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 130,
+   "id": "69eb2304",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [ 
+ "\"import { performance } from 'perf_hooks';\\n\\nfunction digamma(z: number): number {\\n let acc = 0;\\n while (z < 7) {\\n acc -= 1 / z;\\n z += 1;\\n }\\n const z2 = z * z;\\n const z4 = z2 * z2;\\n const z6 = z4 * z2;\\n const z8 = z4 * z4;\\n const z10 = z8 * z2;\\n const z12 = z10 * z2;\\n const series =\\n Math.log(z)\\n - 1 / (2 * z)\\n - 1 / (12 * z2)\\n + 1 / (120 * z4)\\n - 1 / (252 * z6)\\n + 1 / (240 * z8)\\n - 5 / (660 * z10)\\n + 691 / (32760 * z12);\\n return acc + series;\\n}\\n\\nconst N = 200_000_000;\\n\\nconst t0 = performance.now();\\nconst result =\\n 4 - digamma(N + 0.75) + digamma(0.75) + digamma(N + 1.25) - digamma(1.25);\\nconst t1 = performance.now();\\n\\nconsole.log(`Result: ${result.toFixed(12)}`);\\nconsole.log(`Execution Time: ${((t1 - t0) / 1000).toFixed(6)} seconds`);\"" + ] + }, + "execution_count": 130, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "convert(pi)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "id": "2ea56d95", + "metadata": {}, + "outputs": [], + "source": [ + " \n", + "def run_typescript(code):\n", + " write_output(code)\n", + " try:\n", + " subprocess.run(compile_command, check=True, text=True, capture_output=True)\n", + " run_result = subprocess.run(run_command, check=True, text=True, capture_output=True)\n", + " return run_result.stdout\n", + " except subprocess.CalledProcessError as e:\n", + " return f\"An error occurred:\\n{e.stderr}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "id": "79d6bd87", + "metadata": {}, + "outputs": [], + "source": [ + "# run_typescript()" + ] + }, + { + "cell_type": "markdown", + "id": "b4799b88", + "metadata": {}, + "source": [ + "## User Interface" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "id": "8486ce70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7864\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 133, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with gr.Blocks(\n", + " theme=gr.themes.Monochrome(), title=\"Port from Python to TypeScript\"\n", + ") as ui:\n", + " with gr.Row(equal_height=True):\n", + " with gr.Column(scale=6):\n", + " python = gr.Code(\n", + " label=\"Python Original Code\",\n", + " value=pi,\n", + " language=\"python\",\n", + " lines=30,\n", + " )\n", + " with gr.Column(scale=6):\n", + " ts = gr.Code(\n", + " label=\"TypeScript (generated)\", value=\"\", language=\"cpp\", lines=26\n", + " )\n", + " with gr.Row(elem_classes=[\"controls\"]):\n", + " python_run = gr.Button(\"Run Python\", elem_classes=[\"run-btn\", \"py\"])\n", + " port = gr.Button(\"Convert to TS\", elem_classes=[\"convert-btn\"])\n", + " ts_run = gr.Button(\"Run TS\", elem_classes=[\"run-btn\", \"ts\"])\n", + "\n", + " with gr.Row(equal_height=True):\n", + " with gr.Column(scale=6):\n", + " python_out = gr.TextArea(label=\"Python Result\", lines=10)\n", + " with gr.Column(scale=6):\n", + " ts_out = gr.TextArea(label=\"TS output\", lines=10)\n", + "\n", + " port.click(fn=convert, inputs=[python], outputs=[ts])\n", + " python_run.click(fn=run_python, inputs=[python], outputs=[python_out])\n", + " ts_run.click(fn=run_typescript, inputs=[ts], outputs=[ts_out])\n", + " \n", + " \n", + "ui.launch(inbrowser=True)" + ] + }, + { + "cell_type": "markdown", + "id": "4663a174", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9033e421", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week4/community-contributions/w4d5-Trade.ipynb b/week4/community-contributions/w4d5-Trade.ipynb new file mode 100644 index 0000000..3a57afa --- /dev/null +++ b/week4/community-contributions/w4d5-Trade.ipynb @@ -0,0 +1,1833 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trading Code Generator\n", + "\n", + "This notebook creates a code generator that produces trading code to buy and sell equities in a simulated environment based on free APIs. 
It uses Gradio for the UI, similar to the approach in day5.ipynb.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import io\n", + "import sys\n", + "import time\n", + "import random\n", + "import numpy as np\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "from IPython.display import display\n", + "from huggingface_hub import InferenceClient\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n", + "Hugging Face Token exists and begins hf_fNncb\n" + ] + } + ], + "source": [ + "load_dotenv(override=True)\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "hf_token = os.getenv('HF_TOKEN')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "if hf_token:\n", + " print(f\"Hugging Face Token exists and begins {hf_token[:8]}\")\n", + "else:\n", + " print(\"Hugging Face Token not set\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "openai_client = OpenAI()\n", + "hf_client = InferenceClient(token=hf_token)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-4o\", \"gpt-3.5-turbo\", \"meta-llama/Llama-2-70b-chat-hf\"]\n", + "\n", + "def generate_with_openai(model, messages):\n", + " response = openai_client.chat.completions.create(\n", + " model=model, \n", + " messages=messages\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "def generate_with_hf(model, messages):\n", + " prompt = \"\"\n", + " for msg in messages:\n", + " role = msg[\"role\"]\n", + " content = msg[\"content\"]\n", + " if role == \"system\":\n", + " prompt += f\"[INST] {content} [/INST]\\n\"\n", + " elif role == \"user\":\n", + " prompt += f\"[INST] {content} [/INST]\\n\"\n", + " else:\n", + " prompt += f\"{content}\\n\"\n", + " \n", + " response = hf_client.text_generation(\n", + " prompt,\n", + " model=model,\n", + " max_new_tokens=1024,\n", + " temperature=0.7,\n", + " repetition_penalty=1.2\n", + " )\n", + " return response\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "CSS = \"\"\"\n", + ":root {\n", + " --py-color: #209dd7;\n", + " --trading-color: #27ae60;\n", + " --accent: #753991;\n", + " --card: #161a22;\n", + " --text: #e9eef5;\n", + "}\n", + "\n", + "/* Full-width layout */\n", + ".gradio-container {\n", + " max-width: 100% !important;\n", + " padding: 0 40px !important;\n", + "}\n", + "\n", + "/* Code card styling */\n", + ".card {\n", + " background: var(--card);\n", + " border: 1px solid rgba(255,255,255,.08);\n", + " border-radius: 14px;\n", + " padding: 10px;\n", + "}\n", + "\n", + "/* Make code block scrollable but fixed height */\n", + "#code-block {\n", + " max-height: 400px !important;\n", + " overflow-y: auto !important;\n", + "}\n", + "\n", + "#code-block .cm-editor {\n", + " height: 400px !important;\n", + "}\n", + "\n", + "/* Buttons */\n", + ".generate-btn button {\n", + " background: var(--accent) !important;\n", + " border-color: rgba(255,255,255,.12) !important;\n", + " color: white !important;\n", + " font-weight: 700;\n", + "}\n", + 
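A note on `generate_with_hf` above: it wraps system and user turns in identical `[INST]` blocks, which Llama-2 chat models tolerate but were not trained on. A minimal sketch of the documented Llama-2 chat layout, where the system prompt sits inside `<<SYS>>` markers — the helper name is ours, not the notebook's:

```python
def llama2_prompt(messages):
    """Sketch of the Llama-2 chat format from the model card:
    [INST] <<SYS>> system <</SYS>> user [/INST].
    Assumes one system message followed by user turns; hypothetical helper."""
    system = ""
    user_parts = []
    for msg in messages:
        if msg["role"] == "system":
            system = msg["content"]
        elif msg["role"] == "user":
            user_parts.append(msg["content"])
    user = "\n".join(user_parts)
    if system:
        return f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"
    return f"[INST] {user} [/INST]"
```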
".run-btn button {\n", + " background: #202631 !important;\n", + " color: var(--text) !important;\n", + " border-color: rgba(255,255,255,.12) !important;\n", + "}\n", + ".run-btn.py button:hover { box-shadow: 0 0 0 2px var(--py-color) inset; }\n", + ".run-btn.trading button:hover { box-shadow: 0 0 0 2px var(--trading-color) inset; }\n", + ".generate-btn button:hover { box-shadow: 0 0 0 2px var(--accent) inset; }\n", + "\n", + "/* Outputs with color tint */\n", + ".py-out textarea {\n", + " background: linear-gradient(180deg, rgba(32,157,215,.18), rgba(32,157,215,.10));\n", + " border: 1px solid rgba(32,157,215,.35) !important;\n", + " color: rgba(32,157,215,1) !important;\n", + " font-weight: 600;\n", + "}\n", + ".trading-out textarea {\n", + " background: linear-gradient(180deg, rgba(39,174,96,.18), rgba(39,174,96,.10));\n", + " border: 1px solid rgba(39,174,96,.35) !important;\n", + " color: rgba(39,174,96,1) !important;\n", + " font-weight: 600;\n", + "}\n", + "\n", + "/* Align controls neatly */\n", + ".controls .wrap {\n", + " gap: 10px;\n", + " justify-content: center;\n", + " align-items: center;\n", + "}\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "You are an expert algorithmic trading code generator. Generate clean, bug-free Python code for trading strategies.\n", + "\n", + "Generate code that:\n", + "1. Uses synthetic data generation only - no API calls\n", + "2. Implements the specified trading strategy\n", + "3. Uses proper error handling\n", + "4. Visualizes strategy performance with buy/sell signals\n", + "5. Calculates performance metrics\n", + "6. Handles edge cases properly\n", + "\n", + "REQUIREMENTS:\n", + "1. Include if __name__ == \"__main__\": block that executes immediately\n", + "2. Define all variables before use\n", + "3. Pass parameters between functions, avoid global variables\n", + "4. NO explanatory text outside of code\n", + "5. NO markdown blocks or language indicators\n", + "6. Code must execute without user input\n", + "7. Use str() for pandas objects in f-strings\n", + "8. Use .copy() for DataFrame views that will be modified\n", + "9. Include min_periods in rolling calculations\n", + "10. Check array lengths before scatter plots\n", + "11. Configure logging properly\n", + "12. Include helper functions for formatting and plotting\n", + "\n", + "Respond ONLY with Python code. No explanations or markdown.\n", + "\"\"\"\n", + "\n", + "def user_prompt_for(description):\n", + " return f\"\"\"\n", + "Generate Python code for a trading strategy:\n", + "\n", + "{description}\n", + "\n", + "Requirements:\n", + "1. Use synthetic data generation only\n", + "2. Implement the strategy exactly as described\n", + "3. Include backtesting functionality\n", + "4. Visualize results with matplotlib\n", + "5. Calculate performance metrics\n", + "6. Handle all edge cases\n", + "7. 
No comments needed\n", + "\n", + "Make the code complete and runnable as-is with all necessary imports.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "def messages_for(description):\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt_for(description)}\n", + " ]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def validate_code(code):\n", + " issues = []\n", + " if \"import yfinance\" not in code and \"from yfinance\" not in code:\n", + " issues.append(\"Missing yfinance import\")\n", + " if \"import matplotlib\" not in code and \"from matplotlib\" not in code:\n", + " issues.append(\"Missing matplotlib import\")\n", + " if \"__name__ == \\\"__main__\\\"\" not in code and \"__name__ == '__main__'\" not in code:\n", + " issues.append(\"Missing if __name__ == '__main__' block\")\n", + " if \"f\\\"\" in code or \"f'\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if ('f\"' in line or \"f'\" in line) and ('data[' in line or '.iloc' in line or '.loc' in line):\n", + " issues.append(f\"Potentially unsafe f-string formatting with pandas objects on line {i+1}\")\n", + " if \"try:\" in code and \"except\" not in code:\n", + " issues.append(\"Try block without except clause\")\n", + " if \"rolling\" in code and \"min_periods\" not in code:\n", + " issues.append(\"Rolling window without min_periods parameter (may produce NaN values)\")\n", + " if \".loc\" in code and \"iloc\" not in code and \"copy()\" not in code:\n", + " issues.append(\"Potential pandas SettingWithCopyWarning - consider using .copy() before modifications\")\n", + " lines = code.split('\\n')\n", + " defined_vars = set()\n", + " for line in lines:\n", + " if line.strip().startswith('#') or not line.strip():\n", + " continue\n", + " if '=' in line and not line.strip().startswith('if') and not line.strip().startswith('elif') and not line.strip().startswith('while'):\n", + " var_name = line.split('=')[0].strip()\n", + " if var_name:\n", + " defined_vars.add(var_name)\n", + " if issues:\n", + " return False, issues\n", + " return True, []\n", + "\n", + "def generate_trading_code(model, description, force_gpt4=False):\n", + " messages = messages_for(description)\n", + " if force_gpt4:\n", + " try:\n", + " reply = generate_with_openai(\"gpt-4o\", messages)\n", + " except Exception as e:\n", + " print(f\"Error using GPT-4o: {e}. 
Falling back to selected model.\")\n", + " if \"gpt\" in model.lower():\n", + " reply = generate_with_openai(model, messages)\n", + " else:\n", + " reply = generate_with_hf(model, messages)\n", + " else:\n", + " if \"gpt\" in model.lower():\n", + " reply = generate_with_openai(model, messages)\n", + " else:\n", + " reply = generate_with_hf(model, messages)\n", + " reply = reply.replace('```python','').replace('```','')\n", + " is_valid, issues = validate_code(reply)\n", + " max_attempts = 3\n", + " attempt = 0\n", + " fix_model = \"gpt-4o\" if force_gpt4 else model\n", + " while not is_valid and attempt < max_attempts and (\"gpt\" in model.lower() or force_gpt4):\n", + " attempt += 1\n", + " fix_messages = messages.copy()\n", + " fix_messages.append({\"role\": \"assistant\", \"content\": reply})\n", + " fix_request = f\"\"\"The code has the following issues that need to be fixed:\n", + "{chr(10).join([f\"- {issue}\" for issue in issues])}\n", + "\n", + "Please provide a completely corrected version that addresses these issues. Make sure to:\n", + "\n", + "1. Avoid using f-strings with pandas Series or DataFrame objects directly\n", + "2. Always handle NaN values in calculations with proper checks\n", + "3. Use proper error handling with try/except blocks around all API calls and calculations\n", + "4. Include min_periods parameter in rolling window calculations\n", + "5. Use .copy() when creating views of DataFrames that will be modified\n", + "6. Make sure all variables are properly defined before use\n", + "7. Add yfinance timeout settings: yf.set_timeout(30)\n", + "8. Add proper logging for all steps\n", + "9. Use synthetic data generation as a fallback if API calls fail\n", + "10. Include proper if __name__ == \"__main__\" block\n", + "\n", + "Return ONLY the corrected code with no explanation or markdown formatting.\n", + "\"\"\"\n", + " fix_messages.append({\"role\": \"user\", \"content\": fix_request})\n", + " try:\n", + " if force_gpt4:\n", + " fixed_reply = generate_with_openai(\"gpt-4o\", fix_messages)\n", + " else:\n", + " if \"gpt\" in model.lower():\n", + " fixed_reply = generate_with_openai(model, fix_messages)\n", + " else:\n", + " fixed_reply = generate_with_hf(model, fix_messages)\n", + " fixed_reply = fixed_reply.replace('```python','').replace('```','')\n", + " is_fixed_valid, fixed_issues = validate_code(fixed_reply)\n", + " if is_fixed_valid or len(fixed_issues) < len(issues):\n", + " reply = fixed_reply\n", + " is_valid = is_fixed_valid\n", + " issues = fixed_issues\n", + " except Exception as e:\n", + " print(f\"Error during fix attempt {attempt}: {e}\")\n", + " reply = add_safety_features(reply)\n", + " return reply\n", + "\n", + "def add_safety_features(code):\n", + " if \"pandas\" in code:\n", + " safety_imports = \"\"\"\n", + "import pandas as pd\n", + "pd.set_option('display.float_format', '{:.5f}'.format)\n", + "\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " return obj\n", + "\"\"\"\n", + " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", + " if import_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(import_lines[-1] + 1, safety_imports)\n", + " code = '\\n'.join(lines)\n", + " code = code.replace(\"yf.set_timeout(30)\", \"\")\n", + " code = code.replace(\"yf.pdr_override()\", \"\")\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if 'f\"' in line or \"f'\" in line:\n", + " if any(term in line for term in 
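The `validate_code` checks above are substring heuristics, so they can pass code that does not even parse. A cheap first gate worth running before them is a real parse with the standard-library `ast` module — a sketch; `syntax_issues` is not part of the notebook:

```python
import ast

def syntax_issues(code: str) -> list[str]:
    # A parse failure is always a real defect, unlike the heuristic
    # substring checks, which can only flag likely problems.
    try:
        ast.parse(code)
    except SyntaxError as e:
        return [f"SyntaxError on line {e.lineno}: {e.msg}"]
    return []
```

Running this ahead of `validate_code` would let the repair loop distinguish hard failures from style warnings.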
['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", + " for term in ['.mean()', '.sum()', '.std()', '.min()', '.max()']:\n", + " if term in line:\n", + " lines[i] = line.replace(f\"{term}\", f\"{term})\")\n", + " lines[i] = lines[i].replace(\"f\\\"\", \"f\\\"{safe_format(\")\n", + " lines[i] = lines[i].replace(\"f'\", \"f'{safe_format(\")\n", + " code = '\\n'.join(lines)\n", + " if \"plt.scatter\" in code or \".scatter\" in code:\n", + " scatter_safety = \"\"\"\n", + "def safe_scatter(ax, x, y, *args, **kwargs):\n", + " if len(x) != len(y):\n", + " min_len = min(len(x), len(y))\n", + " x = x[:min_len]\n", + " y = y[:min_len]\n", + " if len(x) == 0 or len(y) == 0:\n", + " return None\n", + " return ax.scatter(x, y, *args, **kwargs)\n", + "\"\"\"\n", + " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", + " if func_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(func_lines[0], scatter_safety)\n", + " code = '\\n'.join(lines)\n", + " code = code.replace(\"plt.scatter(\", \"safe_scatter(plt.gca(), \")\n", + " code = code.replace(\".scatter(\", \"safe_scatter(\")\n", + " if \"yfinance\" in code and \"generate_synthetic_data\" not in code:\n", + " synthetic_data_func = \"\"\"\n", + "def generate_synthetic_data(ticker='AAPL', start_date=None, end_date=None, days=252, seed=42):\n", + " import numpy as np\n", + " import pandas as pd\n", + " from datetime import datetime, timedelta\n", + " if start_date is None:\n", + " end_date = datetime.now()\n", + " start_date = end_date - timedelta(days=days)\n", + " elif end_date is None:\n", + " if isinstance(start_date, str):\n", + " start_date = pd.to_datetime(start_date)\n", + " end_date = datetime.now()\n", + " np.random.seed(seed)\n", + " if isinstance(start_date, str):\n", + " start = pd.to_datetime(start_date)\n", + " else:\n", + " start = start_date\n", + " if isinstance(end_date, str):\n", + " end = pd.to_datetime(end_date)\n", + " else:\n", + " end = end_date\n", + " days = (end - start).days + 1\n", + " price = 100\n", + " prices = [price]\n", + " for _ in range(days):\n", + " change = np.random.normal(0, 0.01)\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " dates = pd.date_range(start=start, end=end, periods=len(prices))\n", + " df = pd.DataFrame({\n", + " 'Open': prices[:-1],\n", + " 'High': [p * 1.01 for p in prices[:-1]],\n", + " 'Low': [p * 0.99 for p in prices[:-1]],\n", + " 'Close': prices[1:],\n", + " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", + " }, index=dates[:-1])\n", + " return df\n", + "\"\"\"\n", + " func_lines = [i for i, line in enumerate(code.split('\\n')) if line.startswith('def ')]\n", + " if func_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(func_lines[0], synthetic_data_func)\n", + " code = '\\n'.join(lines)\n", + " if \"logging\" in code and \"basicConfig\" not in code:\n", + " logging_config = \"\"\"\n", + "import logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\"\"\"\n", + " import_lines = [i for i, line in enumerate(code.split('\\n')) if 'import' in line]\n", + " if import_lines:\n", + " lines = code.split('\\n')\n", + " lines.insert(import_lines[-1] + 1, logging_config)\n", + " code = '\\n'.join(lines)\n", + " if \"yfinance\" in code and \"try:\" not in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"yf.download\" in line and 
\"try:\" not in lines[max(0, i-5):i]:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " lines[i] = f\"{indent_str}try:\\n{indent_str} {line}\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error fetching data: {{e}}\\\")\\n{indent_str} # Use synthetic data as fallback\\n{indent_str} data = generate_synthetic_data(ticker, start_date, end_date)\"\n", + " code = '\\n'.join(lines)\n", + " break\n", + " if \"synthetic data\" in code.lower() and \"yf.download\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"yf.download\" in line:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " comment = f\"{indent_str}# Using synthetic data instead of API call\\n\"\n", + " synthetic = f\"{indent_str}data = generate_synthetic_data(ticker, start_date, end_date)\\n\"\n", + " lines[i] = f\"{indent_str}# {line.strip()} # Commented out to avoid API issues\"\n", + " lines.insert(i+1, comment + synthetic)\n", + " code = '\\n'.join(lines)\n", + " break\n", + " if \"plt.figure\" in code:\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"plt.figure\" in line and \"try:\" not in lines[max(0, i-5):i]:\n", + " indent = len(line) - len(line.lstrip())\n", + " indent_str = \" \" * indent\n", + " try_line = f\"{indent_str}try:\\n{indent_str} \"\n", + " except_line = f\"\\n{indent_str}except Exception as e:\\n{indent_str} logging.error(f\\\"Error in plotting: {{e}}\\\")\"\n", + " j = i\n", + " while j < len(lines) and (j == i or lines[j].startswith(indent_str)):\n", + " j += 1\n", + " for k in range(i, j):\n", + " if lines[k].strip():\n", + " lines[k] = indent_str + \" \" + lines[k].lstrip()\n", + " lines.insert(i, try_line.rstrip())\n", + " lines.insert(j+1, except_line)\n", + " code = '\\n'.join(lines)\n", + " break\n", + " lines = code.split('\\n')\n", + " for i, line in enumerate(lines):\n", + " if \"print(\" in line and any(term in line for term in ['data[', '.iloc', '.loc', 'Series', 'DataFrame']):\n", + " lines[i] = line.replace(\"print(\", \"print(safe_format(\")\n", + " if \"))\" not in lines[i] and \"),\" in lines[i]:\n", + " lines[i] = lines[i].replace(\"),\", \")),\", 1)\n", + " elif \"))\" not in lines[i] and \")\" in lines[i]:\n", + " lines[i] = lines[i].replace(\")\", \"))\", 1)\n", + " code = '\\n'.join(lines)\n", + " return code\n" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": {}, + "outputs": [], + "source": [ + "def run_python(code):\n", + " # Create a completely separate namespace for execution\n", + " namespace = {\n", + " '__name__': '__main__',\n", + " '__builtins__': __builtins__\n", + " }\n", + " \n", + " # Modify the code to use a non-interactive matplotlib backend\n", + " # and fix pandas formatting issues\n", + " modified_code = \"\"\"\n", + "import matplotlib\n", + "matplotlib.use('Agg') # Use non-interactive backend\n", + "\n", + "# Import yfinance without setting timeout (not available in all versions)\n", + "import yfinance as yf\n", + "\n", + "# Configure logging to show in the output\n", + "import logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Fix pandas formatting issues\n", + "import pandas as pd\n", + "pd.set_option('display.float_format', '{:.5f}'.format)\n", + "\n", + "# Override print to ensure it flushes immediately\n", + "import 
builtins\n", + "original_print = builtins.print\n", + "def custom_print(*args, **kwargs):\n", + " result = original_print(*args, **kwargs)\n", + " import sys\n", + " sys.stdout.flush()\n", + " return result\n", + "builtins.print = custom_print\n", + "\n", + "# Helper function to safely format pandas objects\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " else:\n", + " return obj\n", + "\"\"\"\n", + " \n", + " # Add the user's code\n", + " modified_code += \"\\n\" + code\n", + " \n", + " # Capture all output\n", + " output_buffer = io.StringIO()\n", + " \n", + " # Save original stdout and redirect to our buffer\n", + " original_stdout = sys.stdout\n", + " sys.stdout = output_buffer\n", + " \n", + " # Add timestamp for execution start\n", + " print(f\"[{time.strftime('%H:%M:%S')}] Executing code...\")\n", + " \n", + " try:\n", + " # Execute the modified code\n", + " exec(modified_code, namespace)\n", + " print(f\"\\n[{time.strftime('%H:%M:%S')}] Execution completed successfully.\")\n", + " \n", + " except ModuleNotFoundError as e:\n", + " missing_module = str(e).split(\"'\")[1]\n", + " print(f\"\\nError: Missing module '{missing_module}'. Click 'Install Dependencies' to install it.\")\n", + " namespace[\"__missing_module__\"] = missing_module\n", + " \n", + " except Exception as e:\n", + " print(f\"\\n[{time.strftime('%H:%M:%S')}] Error during execution: {str(e)}\")\n", + " import traceback\n", + " print(traceback.format_exc())\n", + " \n", + " finally:\n", + " # Restore original stdout\n", + " sys.stdout = original_stdout\n", + " \n", + " # Return the captured output\n", + " return output_buffer.getvalue()\n", + "\n", + "def install_dependencies(code):\n", + " import re\n", + " import subprocess\n", + " \n", + " import_pattern = r'(?:from|import)\\s+([a-zA-Z0-9_]+)(?:\\s+(?:import|as))?'\n", + " imports = re.findall(import_pattern, code)\n", + " \n", + " std_libs = ['os', 'sys', 'io', 'time', 'datetime', 'random', 'math', 're', 'json', \n", + " 'collections', 'itertools', 'functools', 'operator', 'pathlib', 'typing']\n", + " \n", + " modules_to_install = [module for module in imports if module not in std_libs]\n", + " \n", + " if not modules_to_install:\n", + " return \"No external dependencies found to install.\"\n", + " \n", + " results = []\n", + " for module in modules_to_install:\n", + " try:\n", + " result = subprocess.run(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", module],\n", + " capture_output=True,\n", + " text=True,\n", + " check=False\n", + " )\n", + " \n", + " if result.returncode == 0:\n", + " results.append(f\"Successfully installed {module}\")\n", + " else:\n", + " results.append(f\"Failed to install {module}: {result.stderr}\")\n", + " except Exception as e:\n", + " results.append(f\"Error installing {module}: {str(e)}\")\n", + " \n", + " return \"\\n\".join(results)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [], + "source": [ + "trading_strategies = [\n", + " {\n", + " \"name\": \"Moving Average Crossover\",\n", + " \"description\": \"Moving Average Crossover strategy for S&P 500 stocks. 
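`run_python` above swaps `sys.stdout` by hand and restores it in `finally`; the standard library provides the same capture as a context manager. A stripped-down equivalent for illustration only — it omits the matplotlib backend selection and helper injection that `run_python` performs:

```python
import io
from contextlib import redirect_stdout

def capture_exec(code: str) -> str:
    # Execute code in a fresh namespace and return everything it printed;
    # redirect_stdout restores sys.stdout even if exec() raises.
    buf = io.StringIO()
    namespace = {"__name__": "__main__"}
    with redirect_stdout(buf):
        try:
            exec(code, namespace)
        except Exception as e:
            print(f"Error during execution: {e}")
    return buf.getvalue()
```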
Buy when the 20-day moving average crosses above the 50-day moving average, and sell when it crosses below.\",\n", + " \"buy_signal\": \"20-day MA crosses above 50-day MA\",\n", + " \"sell_signal\": \"20-day MA crosses below 50-day MA\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"RSI Mean Reversion\",\n", + " \"description\": \"Mean reversion strategy that buys stocks when RSI falls below 30 (oversold) and sells when RSI rises above 70 (overbought).\",\n", + " \"buy_signal\": \"RSI below 30 (oversold)\",\n", + " \"sell_signal\": \"RSI above 70 (overbought)\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"Momentum Strategy\",\n", + " \"description\": \"Momentum strategy that buys the top 5 performing stocks from the Dow Jones Industrial Average over the past month and rebalances monthly.\",\n", + " \"buy_signal\": \"Stock in top 5 performers over past month\",\n", + " \"sell_signal\": \"Stock no longer in top 5 performers at rebalance\",\n", + " \"timeframe\": \"Monthly\",\n", + " \"risk_level\": \"High\"\n", + " },\n", + " {\n", + " \"name\": \"Pairs Trading\",\n", + " \"description\": \"Pairs trading strategy that identifies correlated stock pairs and trades on the divergence and convergence of their price relationship.\",\n", + " \"buy_signal\": \"Pairs ratio deviates 2+ standard deviations below mean\",\n", + " \"sell_signal\": \"Pairs ratio returns to mean or exceeds mean\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium-High\"\n", + " },\n", + " {\n", + " \"name\": \"Bollinger Band Breakout\",\n", + " \"description\": \"Volatility breakout strategy that buys when a stock breaks out of its upper Bollinger Band and sells when it reverts to the mean.\",\n", + " \"buy_signal\": \"Price breaks above upper Bollinger Band (2 std dev)\",\n", + " \"sell_signal\": \"Price reverts to middle Bollinger Band (SMA)\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"High\"\n", + " },\n", + " {\n", + " \"name\": \"MACD Crossover\",\n", + " \"description\": \"MACD crossover strategy that buys when the MACD line crosses above the signal line and sells when it crosses below.\",\n", + " \"buy_signal\": \"MACD line crosses above signal line\",\n", + " \"sell_signal\": \"MACD line crosses below signal line\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Medium\"\n", + " },\n", + " {\n", + " \"name\": \"Golden Cross\",\n", + " \"description\": \"Golden Cross strategy that buys when the 50-day moving average crosses above the 200-day moving average and sells on the Death Cross (opposite).\",\n", + " \"buy_signal\": \"50-day MA crosses above 200-day MA\",\n", + " \"sell_signal\": \"50-day MA crosses below 200-day MA\",\n", + " \"timeframe\": \"Daily\",\n", + " \"risk_level\": \"Low\"\n", + " }\n", + "]\n", + "\n", + "sample_strategies = [strategy[\"description\"] for strategy in trading_strategies]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [], + "source": [ + "default_description = \"\"\"\n", + "Create a moving average crossover strategy with the following specifications:\n", + "- Use yfinance to download historical data for a list of stocks (AAPL, MSFT, AMZN, GOOGL, META)\n", + "- Calculate 20-day and 50-day moving averages\n", + "- Generate buy signals when the 20-day MA crosses above the 50-day MA\n", + "- Generate sell signals when the 20-day MA crosses below the 50-day MA\n", + 
"- Implement a simple backtesting framework to evaluate the strategy\n", + "- Calculate performance metrics: total return, annualized return, Sharpe ratio, max drawdown\n", + "- Visualize the equity curve, buy/sell signals, and moving averages\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-22 18:30:21,233 - INFO - HTTP Request: GET http://127.0.0.1:7875/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:30:21,238 - INFO - HTTP Request: HEAD http://127.0.0.1:7875/ \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7875\n", + "* To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 115, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-22 18:30:24,092 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:03,437 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:15,743 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:31:28,425 - ERROR - Error fetching data: periods must be a number, got 2025-10-22 18:31:28.425210\n", + "2025-10-22 18:31:28,429 - INFO - Synthetic data generated for tickers: AAPL\n", + "2025-10-22 18:31:28,432 - INFO - Moving averages calculated with windows 20 and 50\n", + "2025-10-22 18:31:28,434 - INFO - Signals generated based on moving average crossover\n", + "2025-10-22 18:31:28,438 - INFO - Performance calculated\n", + "2025-10-22 18:31:28,438 - INFO - Total Return: -0.010752455100331848, Sharpe Ratio: 0.18162435507214664, Max Drawdown: -0.19919271751608258\n", + "2025-10-22 18:32:28,496 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:32:38,626 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:32:47,779 - ERROR - Error fetching data from yfinance: name 'start_date' is not defined. Using synthetic data.\n", + "2025-10-22 18:33:23,647 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "2025-10-22 18:33:37,829 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + } + ], + "source": [ + "with gr.Blocks(css=CSS, theme=gr.themes.Monochrome(), title=\"Trading Code Generator\") as ui:\n", + " with gr.Row():\n", + " gr.HTML(\"
<h1>Trading Strategy Code Generator</h1>
\")\n", + " \n", + " with gr.Row():\n", + " # Left column - Controls\n", + " with gr.Column(scale=1):\n", + " strategy_dropdown = gr.Dropdown(\n", + " label=\"Select Trading Strategy\",\n", + " choices=[strategy[\"name\"] for strategy in trading_strategies],\n", + " value=trading_strategies[0][\"name\"]\n", + " )\n", + " \n", + " with gr.Accordion(\"Strategy Details\", open=False):\n", + " strategy_info = gr.JSON(\n", + " value=trading_strategies[0]\n", + " )\n", + " \n", + " model = gr.Dropdown(\n", + " label=\"Select Model\",\n", + " choices=models,\n", + " value=models[0]\n", + " )\n", + " \n", + " description = gr.TextArea(\n", + " label=\"Strategy Description (Edit to customize)\",\n", + " value=trading_strategies[0][\"description\"],\n", + " lines=4\n", + " )\n", + " \n", + " with gr.Row():\n", + " generate = gr.Button(\"Generate Code\", variant=\"primary\", size=\"sm\")\n", + " run = gr.Button(\"Run Code\", size=\"sm\")\n", + " install_deps = gr.Button(\"Install Dependencies\", size=\"sm\")\n", + " \n", + " # Right column - Code and Output\n", + " with gr.Column(scale=2):\n", + " trading_code = gr.Code(\n", + " label=\"Generated Trading Code\",\n", + " value=\"\",\n", + " language=\"python\",\n", + " lines=20,\n", + " elem_id=\"code-block\",\n", + " show_label=True\n", + " )\n", + " \n", + " output = gr.TextArea(\n", + " label=\"Execution Output\",\n", + " lines=8,\n", + " elem_classes=[\"trading-out\"]\n", + " )\n", + " \n", + " def update_strategy_info(strategy_name):\n", + " selected = next((s for s in trading_strategies if s[\"name\"] == strategy_name), None)\n", + " if selected:\n", + " return selected, selected[\"description\"]\n", + " return trading_strategies[0], trading_strategies[0][\"description\"]\n", + " \n", + " strategy_dropdown.change(\n", + " fn=update_strategy_info,\n", + " inputs=strategy_dropdown,\n", + " outputs=[strategy_info, description]\n", + " )\n", + " \n", + " # Function to show validation results when generating code\n", + " def generate_with_validation(model, description):\n", + " # Always use GPT-4o for better code quality\n", + " code = generate_trading_code(model, description, force_gpt4=True)\n", + " is_valid, issues = validate_code(code)\n", + " \n", + " validation_message = \"\"\n", + " if is_valid:\n", + " validation_message = \"Code validation passed ✓\"\n", + " else:\n", + " validation_message = \"Code validation warnings:\\n\" + \"\\n\".join([f\"- {issue}\" for issue in issues])\n", + " \n", + " return code, validation_message\n", + " \n", + " generate.click(\n", + " fn=generate_with_validation,\n", + " inputs=[model, description],\n", + " outputs=[trading_code, output]\n", + " )\n", + " \n", + " run.click(\n", + " fn=run_python,\n", + " inputs=[trading_code],\n", + " outputs=[output]\n", + " )\n", + " \n", + " install_deps.click(\n", + " fn=install_dependencies,\n", + " inputs=[trading_code],\n", + " outputs=[output]\n", + " )\n", + "\n", + "ui.launch(inbrowser=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing the Trading Code Generator\n", + "\n", + "Let's test the trading code generator with a specific strategy description and model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_description = \"\"\"\n", + "Create a simple RSI-based mean reversion strategy:\n", + "- Use AAPL stock data for the past 2 years\n", + "- Calculate the 14-day RSI indicator\n", + "- Buy when RSI falls below 30 (oversold)\n", + "- Sell when RSI 
rises above 70 (overbought)\n", + "- Include visualization of entry/exit points\n", + "- Calculate performance metrics\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "\n", + "generated_code = generate_trading_code(test_model, test_description)\n", + "print(\"Generated trading code:\")\n", + "print(generated_code)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " output = run_python(generated_code)\n", + " print(output)\n", + "except Exception as e:\n", + " print(f\"Error running the generated code: {e}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fixed version of the code\n", + "fixed_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.dates as mdates\n", + "from datetime import datetime, timedelta\n", + "import logging\n", + "\n", + "# Set up logging\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", + "\n", + "def calculate_moving_averages(data, short_window=20, long_window=50):\n", + " \\\"\\\"\\\"\n", + " Calculate short and long term moving averages\n", + " \\\"\\\"\\\"\n", + " data['Short_MA'] = data['Close'].rolling(window=short_window, min_periods=1).mean()\n", + " data['Long_MA'] = data['Close'].rolling(window=long_window, min_periods=1).mean()\n", + " return data\n", + "\n", + "def generate_signals(data, short_window=20, long_window=50):\n", + " \\\"\\\"\\\"\n", + " Generate buy/sell signals based on moving average crossover strategy\n", + " \\\"\\\"\\\"\n", + " data['Signal'] = 0\n", + " data['Signal'][short_window:] = np.where(\n", + " data['Short_MA'][short_window:] > data['Long_MA'][short_window:], 1, -1)\n", + " data['Position'] = data['Signal'].shift(1)\n", + " data['Position'].fillna(0, inplace=True) # Fill NaN values with 0\n", + " return data\n", + "\n", + "def backtest_strategy(data):\n", + " \\\"\\\"\\\"\n", + " Backtest the trading strategy and calculate performance metrics\n", + " \\\"\\\"\\\"\n", + " data['Returns'] = data['Close'].pct_change()\n", + " data['Strategy_Returns'] = data['Returns'] * data['Position']\n", + " \n", + " # Replace NaN values with 0\n", + " data['Strategy_Returns'].fillna(0, inplace=True)\n", + "\n", + " cumulative_returns = (1 + data['Strategy_Returns']).cumprod()\n", + " \n", + " # Calculate metrics\n", + " total_return = cumulative_returns.iloc[-1] - 1\n", + " sharpe_ratio = np.sqrt(252) * (data['Strategy_Returns'].mean() / data['Strategy_Returns'].std())\n", + " max_drawdown = ((cumulative_returns / cumulative_returns.cummax()) - 1).min()\n", + "\n", + " metrics = {\n", + " 'Total Return': total_return,\n", + " 'Sharpe Ratio': sharpe_ratio,\n", + " 'Max Drawdown': max_drawdown\n", + " }\n", + "\n", + " return cumulative_returns, metrics\n", + "\n", + "def plot_results(data, cumulative_returns, ticker):\n", + " \\\"\\\"\\\"\n", + " Plot the performance of the trading strategy\n", + " \\\"\\\"\\\"\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [2, 1]})\n", + " \n", + " # Price and MA plot\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " ax1.plot(data.index, data['Short_MA'], label='20-day MA', alpha=0.7)\n", + " ax1.plot(data.index, data['Long_MA'], label='50-day MA', alpha=0.7)\n", + " \n", + " # Add buy/sell signals\n", + " buy_signals = 
data[data['Signal'] > data['Signal'].shift(1)]\n", + " sell_signals = data[data['Signal'] < data['Signal'].shift(1)]\n", + " \n", + " ax1.scatter(buy_signals.index, buy_signals['Close'], marker='^', color='green', s=100, label='Buy Signal')\n", + " ax1.scatter(sell_signals.index, sell_signals['Close'], marker='v', color='red', s=100, label='Sell Signal')\n", + " \n", + " ax1.set_title(f'Moving Average Crossover Strategy on {ticker}')\n", + " ax1.set_ylabel('Price ($)')\n", + " ax1.legend(loc='best')\n", + " ax1.grid(True)\n", + " \n", + " # Returns plot\n", + " ax2.plot(cumulative_returns.index, cumulative_returns, label='Cumulative Strategy Returns', color='blue')\n", + " ax2.set_title('Cumulative Returns')\n", + " ax2.set_xlabel('Date')\n", + " ax2.set_ylabel('Returns')\n", + " ax2.legend(loc='best')\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "if __name__ == \\\"__main__\\\":\n", + " # User inputs\n", + " ticker = 'SPY' # Example: S&P 500 ETF\n", + " start_date = (datetime.now() - timedelta(days=365*2)).strftime('%Y-%m-%d')\n", + " end_date = datetime.now().strftime('%Y-%m-%d')\n", + "\n", + " # Strategy parameters\n", + " short_window = 20\n", + " long_window = 50\n", + "\n", + " # Fetch data\n", + " try:\n", + " logging.info(f\\\"Fetching data for {ticker} from {start_date} to {end_date}...\\\")\n", + " stock_data = yf.download(ticker, start=start_date, end=end_date)\n", + " logging.info(f\\\"Data fetched successfully. Got {len(stock_data)} data points.\\\")\n", + " except Exception as e:\n", + " logging.error(f\\\"Failed to fetch data: {e}\\\")\n", + " raise SystemExit(e)\n", + "\n", + " try:\n", + " # Preprocess and generate signals\n", + " stock_data = calculate_moving_averages(stock_data, short_window, long_window)\n", + " stock_data = generate_signals(stock_data, short_window, long_window)\n", + "\n", + " # Backtest the strategy\n", + " cumulative_returns, metrics = backtest_strategy(stock_data)\n", + "\n", + " # Display metrics\n", + " for key, value in metrics.items():\n", + " logging.info(f\\\"{key}: {value:.4f}\\\")\n", + "\n", + " # Plot results\n", + " plot_results(stock_data, cumulative_returns, ticker)\n", + " except Exception as e:\n", + " logging.error(f\\\"Error while executing strategy: {e}\\\")\n", + "\"\"\"\n", + "\n", + "# Display the fixed code\n", + "print(\"Fixed code:\")\n", + "print(fixed_code)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the fixed code\n", + "output = run_python(fixed_code)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's also update our system_prompt to ensure the generated code works properly\n", + "\n", + "system_prompt = \"\"\"\n", + "You are an expert algorithmic trading code generator. Your task is to generate Python code for trading strategies based on user requirements.\n", + "The code should be well-structured, efficient, and ready to run in a simulated environment.\n", + "\n", + "The generated code should:\n", + "1. Use the yfinance library for fetching stock data\n", + "2. Implement the specified trading strategy\n", + "3. Include proper error handling and logging\n", + "4. Include visualization of the strategy performance with clear buy/sell signals\n", + "5. Calculate and display relevant metrics (returns, Sharpe ratio, drawdown, etc.)\n", + "6. Handle NaN values and edge cases properly\n", + "7. 
Include informative print statements or logging to show progress\n", + "\n", + "IMPORTANT: Make sure all variables are properly defined before use, especially in functions.\n", + "Always pass necessary parameters between functions rather than relying on global variables.\n", + "\n", + "Respond only with Python code. Do not provide any explanation other than occasional comments in the code.\n", + "\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the updated system prompt with a simple strategy\n", + "\n", + "test_description_2 = \"\"\"\n", + "Create a simple Bollinger Bands strategy:\n", + "- Use AAPL stock data for the past 1 year\n", + "- Calculate Bollinger Bands with 20-day SMA and 2 standard deviations\n", + "- Buy when price touches the lower band\n", + "- Sell when price touches the upper band\n", + "- Include visualization of entry/exit points\n", + "- Calculate performance metrics\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "generated_code_2 = generate_trading_code(test_model, test_description_2)\n", + "print(\"Generated trading code with updated prompt:\")\n", + "print(generated_code_2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the run function with live logging\n", + "\n", + "test_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Define ticker and date range\n", + "ticker = 'AAPL'\n", + "end_date = datetime.now()\n", + "start_date = end_date - timedelta(days=365)\n", + "\n", + "# Download data\n", + "print(f\"Downloading data for {ticker}...\")\n", + "data = yf.download(ticker, start=start_date, end=end_date)\n", + "print(f\"Downloaded {len(data)} rows of data\")\n", + "\n", + "# Calculate RSI\n", + "print(\"Calculating RSI...\")\n", + "delta = data['Close'].diff()\n", + "gain = delta.where(delta > 0, 0)\n", + "loss = -delta.where(delta < 0, 0)\n", + "avg_gain = gain.rolling(window=14).mean()\n", + "avg_loss = loss.rolling(window=14).mean()\n", + "rs = avg_gain / avg_loss\n", + "data['RSI'] = 100 - (100 / (1 + rs))\n", + "\n", + "# Generate signals\n", + "print(\"Generating trading signals...\")\n", + "data['Signal'] = 0\n", + "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + "\n", + "# Count signals\n", + "buy_signals = len(data[data['Signal'] == 1])\n", + "sell_signals = len(data[data['Signal'] == -1])\n", + "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + "\n", + "# Print sample of the data\n", + "print(\"\\\\nSample of the processed data:\")\n", + "print(data[['Close', 'RSI', 'Signal']].tail())\n", + "\n", + "print(\"\\\\nAnalysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_code)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test the improved run function again\n", + "\n", + "test_code_2 = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Define ticker and date range\n", + "ticker = 'AAPL'\n", + "end_date = datetime.now()\n", + "start_date = end_date - 
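`test_description_2` below the updated prompt asks for Bollinger Bands built from a 20-day SMA and two standard deviations; for reference, a minimal sketch of that calculation (the function name is ours):

```python
import pandas as pd

def bollinger_bands(close: pd.Series, window: int = 20, num_std: float = 2.0):
    # Middle band: simple moving average; upper/lower bands sit
    # num_std rolling standard deviations above and below it.
    mid = close.rolling(window, min_periods=1).mean()
    std = close.rolling(window, min_periods=1).std().fillna(0.0)
    return mid - num_std * std, mid, mid + num_std * std
```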
timedelta(days=365)\n", + "\n", + "# Download data\n", + "print(f\"Downloading data for {ticker}...\")\n", + "data = yf.download(ticker, start=start_date, end=end_date, progress=False)\n", + "print(f\"Downloaded {len(data)} rows of data\")\n", + "\n", + "# Calculate RSI\n", + "print(\"Calculating RSI...\")\n", + "delta = data['Close'].diff()\n", + "gain = delta.where(delta > 0, 0)\n", + "loss = -delta.where(delta < 0, 0)\n", + "avg_gain = gain.rolling(window=14).mean()\n", + "avg_loss = loss.rolling(window=14).mean()\n", + "rs = avg_gain / avg_loss\n", + "data['RSI'] = 100 - (100 / (1 + rs))\n", + "\n", + "# Generate signals\n", + "print(\"Generating trading signals...\")\n", + "data['Signal'] = 0\n", + "data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + "data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + "\n", + "# Count signals\n", + "buy_signals = len(data[data['Signal'] == 1])\n", + "sell_signals = len(data[data['Signal'] == -1])\n", + "print(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + "\n", + "# Print sample of the data\n", + "print(\"\\\\nSample of the processed data:\")\n", + "print(data[['Close', 'RSI', 'Signal']].tail())\n", + "\n", + "print(\"\\\\nAnalysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_code_2)\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the completely rewritten run function\n", + "\n", + "simple_test = \"\"\"\n", + "print(\"Hello from the trading code generator!\")\n", + "print(\"Testing output capture...\")\n", + "\n", + "# Simulate some data processing\n", + "import numpy as np\n", + "data = np.random.rand(5, 3)\n", + "print(\"Generated random data:\")\n", + "print(data)\n", + "\n", + "# Show a calculation\n", + "result = np.mean(data, axis=0)\n", + "print(\"Mean of each column:\")\n", + "print(result)\n", + "\n", + "print(\"Test complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved code generation with a strategy that typically causes formatting issues\n", + "\n", + "test_description_3 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses AAPL stock data for the past year\n", + "- Calculates both RSI and Bollinger Bands\n", + "- Buys when price is below lower Bollinger Band AND RSI is below 30\n", + "- Sells when price is above upper Bollinger Band OR RSI is above 70\n", + "- Includes proper error handling for all calculations\n", + "- Visualizes the entry/exit points and performance\n", + "\"\"\"\n", + "\n", + "test_model = \"gpt-3.5-turbo\"\n", + "generated_code_3 = generate_trading_code(test_model, test_description_3)\n", + "print(\"Generated trading code with enhanced validation:\")\n", + "print(generated_code_3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test running the generated code\n", + "\n", + "output = run_python(generated_code_3)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the improved run function with a simple example that prints output\n", + "\n", + "simple_test_2 = \"\"\"\n", + "print(\"This is a test of the output 
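The RSI in the test cells above uses plain 14-day rolling means for average gain and loss; Wilder's original definition smooths exponentially with alpha = 1/window. A sketch for comparison, not a drop-in replacement for the generated code:

```python
import pandas as pd

def wilder_rsi(close: pd.Series, window: int = 14) -> pd.Series:
    # Wilder's smoothing is an EWM with alpha = 1/window and adjust=False.
    delta = close.diff().fillna(0.0)
    gain = delta.clip(lower=0).ewm(alpha=1 / window, adjust=False).mean()
    loss = (-delta.clip(upper=0)).ewm(alpha=1 / window, adjust=False).mean()
    rs = gain / loss
    return 100 - 100 / (1 + rs)
```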
capture system\")\n", + "print(\"Line 1 of output\")\n", + "print(\"Line 2 of output\")\n", + "\n", + "# Import and use numpy\n", + "import numpy as np\n", + "data = np.random.rand(3, 3)\n", + "print(\"Random matrix:\")\n", + "print(data)\n", + "\n", + "# Create a simple plot\n", + "import matplotlib.pyplot as plt\n", + "plt.figure(figsize=(8, 4))\n", + "plt.plot([1, 2, 3, 4], [10, 20, 25, 30], 'ro-')\n", + "plt.title('Simple Plot')\n", + "plt.xlabel('X axis')\n", + "plt.ylabel('Y axis')\n", + "plt.grid(True)\n", + "plt.savefig('simple_plot.png') # Save instead of showing\n", + "print(\"Plot saved to simple_plot.png\")\n", + "\n", + "print(\"Test complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test_2)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a simpler example that won't time out\n", + "\n", + "simple_test_3 = \"\"\"\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import logging\n", + "\n", + "# Set up logging\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", + "\n", + "# Generate synthetic stock data\n", + "def generate_stock_data(days=252, volatility=0.01):\n", + " logging.info(f\"Generating {days} days of synthetic stock data\")\n", + " np.random.seed(42)\n", + " price = 100\n", + " prices = [price]\n", + " \n", + " for _ in range(days - 1):\n", + " change = np.random.normal(0, volatility)\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " \n", + " dates = pd.date_range(end=pd.Timestamp.today(), periods=days)\n", + " df = pd.DataFrame({\n", + " 'Close': prices,\n", + " 'Open': [p * (1 - volatility/2) for p in prices],\n", + " 'High': [p * (1 + volatility) for p in prices],\n", + " 'Low': [p * (1 - volatility) for p in prices],\n", + " 'Volume': [np.random.randint(100000, 10000000) for _ in range(days)]\n", + " }, index=dates)\n", + " \n", + " logging.info(f\"Generated data with shape {df.shape}\")\n", + " return df\n", + "\n", + "# Calculate RSI\n", + "def calculate_rsi(data, window=14):\n", + " logging.info(f\"Calculating RSI with {window}-day window\")\n", + " delta = data['Close'].diff()\n", + " gain = delta.where(delta > 0, 0).rolling(window=window, min_periods=1).mean()\n", + " loss = -delta.where(delta < 0, 0).rolling(window=window, min_periods=1).mean()\n", + " \n", + " rs = gain / loss\n", + " rsi = 100 - (100 / (1 + rs))\n", + " return rsi\n", + "\n", + "# Main function\n", + "if __name__ == \"__main__\":\n", + " # Generate data\n", + " data = generate_stock_data()\n", + " \n", + " # Calculate RSI\n", + " data['RSI'] = calculate_rsi(data)\n", + " \n", + " # Generate signals\n", + " logging.info(\"Generating trading signals\")\n", + " data['Signal'] = 0\n", + " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + " \n", + " # Count signals\n", + " buy_signals = len(data[data['Signal'] == 1])\n", + " sell_signals = len(data[data['Signal'] == -1])\n", + " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + " \n", + " # Plot the results\n", + " logging.info(\"Creating visualization\")\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [3, 1]})\n", + " \n", + " # Price chart\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " 
\n", + " # Add buy/sell signals\n", + " buy_points = data[data['Signal'] == 1]\n", + " sell_points = data[data['Signal'] == -1]\n", + " \n", + " ax1.scatter(buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", + " ax1.scatter(sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", + " \n", + " ax1.set_title('Stock Price with RSI Signals')\n", + " ax1.set_ylabel('Price')\n", + " ax1.legend()\n", + " ax1.grid(True)\n", + " \n", + " # RSI chart\n", + " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", + " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", + " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", + " ax2.set_title('RSI Indicator')\n", + " ax2.set_ylabel('RSI')\n", + " ax2.set_ylim(0, 100)\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('rsi_strategy.png')\n", + " logging.info(\"Plot saved to rsi_strategy.png\")\n", + " \n", + " # Print sample of data\n", + " logging.info(\"Sample of the processed data:\")\n", + " print(data[['Close', 'RSI', 'Signal']].tail())\n", + " \n", + " logging.info(\"Analysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(simple_test_3)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the enhanced code generation with GPT-4o\n", + "\n", + "test_description_4 = \"\"\"\n", + "Create a trading strategy that:\n", + "- Uses both MACD and RSI indicators\n", + "- Buys when MACD crosses above signal line AND RSI is below 40\n", + "- Sells when MACD crosses below signal line OR RSI is above 70\n", + "- Includes proper visualization with buy/sell signals\n", + "- Uses synthetic data if API calls fail\n", + "- Calculates performance metrics including Sharpe ratio and max drawdown\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with GPT-4o...\")\n", + "generated_code_4 = generate_trading_code(\"gpt-4o\", test_description_4, force_gpt4=True)\n", + "print(\"Code generation complete. 
Validating...\")\n", + "is_valid, issues = validate_code(generated_code_4)\n", + "\n", + "if issues:\n", + " print(f\"Validation found {len(issues)} issues:\")\n", + " for issue in issues:\n", + " print(f\"- {issue}\")\n", + "else:\n", + " print(\"Code validation passed ✓\")\n", + "\n", + "print(\"\\nGenerated code snippet (first 20 lines):\")\n", + "print(\"\\n\".join(generated_code_4.split(\"\\n\")[:20]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's run the generated code to test it\n", + "\n", + "output = run_python(generated_code_4)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's test again with the fixed timeout setting\n", + "\n", + "test_description_5 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses synthetic data generation to avoid API timeouts\n", + "- Implements a simple moving average crossover (5-day and 20-day)\n", + "- Includes proper visualization with buy/sell signals\n", + "- Calculates basic performance metrics\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with proper yfinance timeout settings...\")\n", + "generated_code_5 = generate_trading_code(\"gpt-4o\", test_description_5, force_gpt4=True)\n", + "print(\"Code generation complete.\")\n", + "\n", + "# Run the generated code\n", + "output = run_python(generated_code_5)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a simpler strategy that focuses on scatter plot safety\n", + "\n", + "test_description_6 = \"\"\"\n", + "Create a simple trading strategy that:\n", + "- Uses synthetic data generation only (no API calls)\n", + "- Implements a simple RSI-based strategy (buy when RSI < 30, sell when RSI > 70)\n", + "- Includes visualization with buy/sell signals using scatter plots\n", + "- Calculates basic performance metrics\n", + "- Uses proper error handling for all operations\n", + "\"\"\"\n", + "\n", + "print(\"Generating trading code with scatter plot safety...\")\n", + "generated_code_6 = generate_trading_code(\"gpt-4o\", test_description_6, force_gpt4=True)\n", + "print(\"Code generation complete.\")\n", + "\n", + "# Run the generated code\n", + "output = run_python(generated_code_6)\n", + "print(\"Execution output:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test with a fixed example that properly handles pandas formatting and scatter plots\n", + "\n", + "test_fixed_code = \"\"\"\n", + "import yfinance as yf\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import logging\n", + "from datetime import datetime, timedelta\n", + "\n", + "# Configure logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Helper function for safe formatting of pandas objects\n", + "def safe_format(obj):\n", + " if isinstance(obj, (pd.Series, pd.DataFrame)):\n", + " return str(obj)\n", + " return obj\n", + "\n", + "# Helper function to safely create scatter plots\n", + "def safe_scatter(ax, x, y, *args, **kwargs):\n", + " # Ensure x and y are the same length\n", + " if len(x) != len(y):\n", + 
" logging.warning(f\"Scatter plot inputs have different lengths: x={len(x)}, y={len(y)}\")\n", + " # Find the minimum length\n", + " min_len = min(len(x), len(y))\n", + " x = x[:min_len]\n", + " y = y[:min_len]\n", + " \n", + " # Check for empty arrays\n", + " if len(x) == 0 or len(y) == 0:\n", + " logging.warning(\"Empty arrays passed to scatter plot, skipping\")\n", + " return None\n", + " \n", + " return ax.scatter(x, y, *args, **kwargs)\n", + "\n", + "# Generate synthetic data\n", + "def generate_synthetic_data(ticker='AAPL', days=252, seed=42):\n", + " logging.info(f\"Generating synthetic data for {ticker}\")\n", + " np.random.seed(seed)\n", + " \n", + " # Generate price data\n", + " price = 100 # Starting price\n", + " prices = [price]\n", + " \n", + " for _ in range(days):\n", + " change = np.random.normal(0, 0.01) # 1% volatility\n", + " price *= (1 + change)\n", + " prices.append(price)\n", + " \n", + " # Create date range\n", + " end_date = datetime.now()\n", + " start_date = end_date - timedelta(days=days)\n", + " dates = pd.date_range(start=start_date, end=end_date, periods=len(prices))\n", + " \n", + " # Create DataFrame\n", + " df = pd.DataFrame({\n", + " 'Open': prices[:-1],\n", + " 'High': [p * 1.01 for p in prices[:-1]],\n", + " 'Low': [p * 0.99 for p in prices[:-1]],\n", + " 'Close': prices[1:],\n", + " 'Volume': [np.random.randint(1000000, 10000000) for _ in range(len(prices)-1)]\n", + " }, index=dates[:-1])\n", + " \n", + " logging.info(f\"Generated {len(df)} days of data for {ticker}\")\n", + " return df\n", + "\n", + "# Calculate RSI\n", + "def calculate_rsi(data, window=14):\n", + " logging.info(f\"Calculating RSI with {window}-day window\")\n", + " delta = data['Close'].diff()\n", + " gain = delta.where(delta > 0, 0)\n", + " loss = -delta.where(delta < 0, 0)\n", + " \n", + " avg_gain = gain.rolling(window=window, min_periods=1).mean()\n", + " avg_loss = loss.rolling(window=window, min_periods=1).mean()\n", + " \n", + " rs = avg_gain / avg_loss\n", + " rsi = 100 - (100 / (1 + rs))\n", + " return rsi\n", + "\n", + "# Generate signals\n", + "def generate_signals(data):\n", + " logging.info(\"Generating trading signals\")\n", + " data['Signal'] = 0\n", + " data.loc[data['RSI'] < 30, 'Signal'] = 1 # Buy signal\n", + " data.loc[data['RSI'] > 70, 'Signal'] = -1 # Sell signal\n", + " \n", + " # Count signals\n", + " buy_signals = len(data[data['Signal'] == 1])\n", + " sell_signals = len(data[data['Signal'] == -1])\n", + " logging.info(f\"Generated {buy_signals} buy signals and {sell_signals} sell signals\")\n", + " return data\n", + "\n", + "# Backtest strategy\n", + "def backtest_strategy(data):\n", + " logging.info(\"Backtesting strategy\")\n", + " data['Returns'] = data['Close'].pct_change()\n", + " data['Strategy'] = data['Signal'].shift(1) * data['Returns']\n", + " \n", + " # Replace NaN values\n", + " data['Strategy'].fillna(0, inplace=True)\n", + " \n", + " # Calculate cumulative returns\n", + " data['Cumulative'] = (1 + data['Strategy']).cumprod()\n", + " \n", + " # Calculate metrics\n", + " total_return = data['Cumulative'].iloc[-1] - 1\n", + " sharpe = np.sqrt(252) * data['Strategy'].mean() / data['Strategy'].std()\n", + " max_dd = (data['Cumulative'] / data['Cumulative'].cummax() - 1).min()\n", + " \n", + " logging.info(f\"Total Return: {total_return:.4f}\")\n", + " logging.info(f\"Sharpe Ratio: {sharpe:.4f}\")\n", + " logging.info(f\"Max Drawdown: {max_dd:.4f}\")\n", + " \n", + " return data\n", + "\n", + "# Visualize results\n", + "def visualize_results(data, 
ticker):\n", + " logging.info(\"Creating visualization\")\n", + " try:\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})\n", + " \n", + " # Price chart\n", + " ax1.plot(data.index, data['Close'], label='Close Price')\n", + " \n", + " # Add buy/sell signals\n", + " buy_points = data[data['Signal'] == 1]\n", + " sell_points = data[data['Signal'] == -1]\n", + " \n", + " # Use safe scatter to avoid \"x and y must be the same size\" error\n", + " if not buy_points.empty:\n", + " safe_scatter(ax1, buy_points.index, buy_points['Close'], marker='^', color='g', s=100, label='Buy')\n", + " \n", + " if not sell_points.empty:\n", + " safe_scatter(ax1, sell_points.index, sell_points['Close'], marker='v', color='r', s=100, label='Sell')\n", + " \n", + " ax1.set_title(f'RSI Strategy for {ticker}')\n", + " ax1.set_ylabel('Price')\n", + " ax1.legend()\n", + " ax1.grid(True)\n", + " \n", + " # RSI chart\n", + " ax2.plot(data.index, data['RSI'], color='purple', label='RSI')\n", + " ax2.axhline(y=70, color='r', linestyle='--', alpha=0.5)\n", + " ax2.axhline(y=30, color='g', linestyle='--', alpha=0.5)\n", + " ax2.set_title('RSI Indicator')\n", + " ax2.set_ylabel('RSI')\n", + " ax2.set_ylim(0, 100)\n", + " ax2.grid(True)\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig('rsi_strategy.png')\n", + " logging.info(\"Plot saved to rsi_strategy.png\")\n", + " except Exception as e:\n", + " logging.error(f\"Error in visualization: {e}\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " # Settings\n", + " ticker = 'AAPL'\n", + " days = 252 # One year of trading days\n", + " \n", + " # Generate data\n", + " data = generate_synthetic_data(ticker, days)\n", + " \n", + " # Calculate RSI\n", + " data['RSI'] = calculate_rsi(data)\n", + " \n", + " # Generate signals\n", + " data = generate_signals(data)\n", + " \n", + " # Backtest strategy\n", + " data = backtest_strategy(data)\n", + " \n", + " # Visualize results\n", + " visualize_results(data, ticker)\n", + " \n", + " # Print sample of data\n", + " logging.info(\"Sample of the processed data:\")\n", + " print(data[['Close', 'RSI', 'Signal', 'Strategy', 'Cumulative']].tail())\n", + " \n", + " logging.info(\"Analysis complete!\")\n", + "\"\"\"\n", + "\n", + "output = run_python(test_fixed_code)\n", + "print(\"Output from execution:\")\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test the fix for indentation error in plotting code\n", + "\n", + "test_code_with_plotting = \"\"\"\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import logging\n", + "\n", + "# Configure logging\n", + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format='[%(asctime)s] %(levelname)s: %(message)s',\n", + " datefmt='%H:%M:%S'\n", + ")\n", + "\n", + "# Simple plotting function\n", + "def create_plot():\n", + " # Generate some data\n", + " x = np.linspace(0, 10, 100)\n", + " y = np.sin(x)\n", + " \n", + " # Create plot\n", + " logging.info(\"Creating sine wave plot\")\n", + " plt.figure(figsize=(10, 6))\n", + " plt.plot(x, y)\n", + " plt.title('Sine Wave')\n", + " plt.xlabel('X')\n", + " plt.ylabel('Y')\n", + " plt.grid(True)\n", + " plt.savefig('sine_wave.png')\n", + " logging.info(\"Plot saved to sine_wave.png\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " create_plot()\n", + "\"\"\"\n", + "\n", + "# Apply safety features to the code\n", + "enhanced_code = add_safety_features(test_code_with_plotting)\n", 
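+ "\n",
+ "# For reference, a minimal sketch (an assumption, not the actual helper) of\n",
+ "# the kind of transformation add_safety_features, defined earlier in this\n",
+ "# notebook, might apply: force a headless matplotlib backend and wrap the\n",
+ "# whole script in a try/except that logs failures. Illustrative only; the\n",
+ "# sketch below is never called.\n",
+ "import textwrap\n",
+ "\n",
+ "def add_safety_features_sketch(code: str) -> str:\n",
+ "    # Prepend a headless backend so plt.savefig works without a display\n",
+ "    header = (\n",
+ "        \"import matplotlib\\n\"\n",
+ "        \"matplotlib.use('Agg')\\n\"\n",
+ "        \"import logging\\n\"\n",
+ "    )\n",
+ "    # Indent the original script so it nests under the try: block\n",
+ "    body = textwrap.indent(code, \"    \")\n",
+ "    return (\n",
+ "        header\n",
+ "        + \"try:\\n\"\n",
+ "        + body\n",
+ "        + \"\\nexcept Exception as exc:\\n\"\n",
+ "        + \"    logging.error('Generated code failed: %s', exc)\\n\"\n",
+ "    )\n",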
+ "print(\"Code with safety features applied:\")\n", + "print(enhanced_code)\n", + "\n", + "# Run the enhanced code\n", + "output = run_python(enhanced_code)\n", + "print(\"\\nExecution output:\")\n", + "print(output)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/week4/community-contributions/week4_exercise_solution-Stephen.ipynb b/week4/community-contributions/week4_exercise_solution-Stephen.ipynb new file mode 100644 index 0000000..07d5155 --- /dev/null +++ b/week4/community-contributions/week4_exercise_solution-Stephen.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ed8c52b6", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "ollama_api_key = os.getenv('OLLAMA_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + "\n", + "if ollama_api_key:\n", + " print(f\"OLLAMA API Key exists and begins {ollama_api_key[:2]}\")\n", + "else:\n", + " print(\"OLLAMA API Key not set (and this is optional)\")\n", + "\n", + "ollama_url = \"http://localhost:11434/v1\"\n", + "\n", + "openai = OpenAI()\n", + "ollama = OpenAI(api_key=ollama_api_key, base_url=ollama_url)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "c628f95e", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt_doc = \"\"\"You are an expert Python developer and code reviewer.\n", + "Your job is to read the user's provided function, and return:\n", + "1. A concise, PEP-257-compliant docstring summarizing what the function does, clarifying types, parameters, return values, and side effects.\n", + "2. Helpful inline comments that improve both readability and maintainability, without restating what the code obviously does.\n", + "\n", + "Only output the function, not explanations or additional text. \n", + "Do not modify variable names or refactor the function logic.\n", + "Your response should improve the code's clarity and documentation, making it easier for others to understand and maintain.\n", + "Don't be extremely verbose.\n", + "Your answer should be at a senior level of expertise.\n", + "\"\"\"\n", + "\n", + "system_prompt_tests = \"\"\"You are a seasoned Python developer and testing expert.\n", + "Your task is to read the user's provided function, and generate:\n", + "1. A concise set of meaningful unit tests that thoroughly validate the function's correctness, including typical, edge, and error cases.\n", + "2. The tests should be written for pytest (or unittest if pytest is not appropriate), use clear, descriptive names, and avoid unnecessary complexity.\n", + "3. 
If dependencies or mocking are needed, include minimal necessary setup code (but avoid over-mocking).\n", + "\n", + "Only output the relevant test code, not explanations or extra text.\n", + "Do not change the original function; focus solely on comprehensive, maintainable test coverage that other developers can easily understand and extend.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "4bb84e6c", + "metadata": {}, + "outputs": [], + "source": [ + "models = [\"gpt-4.1-mini\", \"llama3.1\"]\n", + "clients = {\"gpt-4.1-mini\": openai, \"llama3.1\": ollama}\n", + "\n", + "def generate_documentation(code, model):\n", + " response = clients[model].chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt_doc},\n", + " {\"role\": \"user\", \"content\": code}\n", + " ],\n", + " stream=True\n", + " )\n", + " output = \"\"\n", + " for chunk in response:\n", + " output += chunk.choices[0].delta.content or \"\"\n", + " yield output.replace(\"```python\", \"\").replace(\"```\", \"\")\n", + "\n", + "def generate_tests(code, model):\n", + " response = clients[model].chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt_tests},\n", + " {\"role\": \"user\", \"content\": code}\n", + " ],\n", + " stream=True\n", + " )\n", + " output = \"\"\n", + " for chunk in response:\n", + " output += chunk.choices[0].delta.content or \"\"\n", + " yield output.replace(\"```python\", \"\").replace(\"```\", \"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e65b26", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks(theme=gr.themes.Soft(spacing_size=gr.themes.sizes.spacing_sm, radius_size=gr.themes.sizes.radius_none)) as ui:\n", + " gr.Markdown(\"# Python Toolbox\", elem_id=\"app-title\")\n", + " \n", + " with gr.Tab(\"Docstring Generator\") as tab1:\n", + " gr.Markdown(\"## Docstring & Comment Generator\")\n", + " gr.Markdown(\"Paste your function below to generate helpful docstrings and inline comments!\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " code_input = gr.Code(label=\"Your Python function here\", lines=20, language=\"python\")\n", + " model_dropdown = gr.Dropdown(choices=models, value=models[0], label=\"Select model\")\n", + " submit_doc_btn = gr.Button(\"Generate docstring & comments\")\n", + " with gr.Column():\n", + " code_output = gr.Code(label=\"New function with docstring and comments\", language=\"python\")\n", + "\n", + " submit_doc_btn.click(\n", + " generate_documentation, \n", + " inputs=[code_input, model_dropdown], \n", + " outputs=code_output\n", + " )\n", + "\n", + " with gr.Tab(\"Unit Tests Generator\") as tab2:\n", + " gr.Markdown(\"## Unit Test Generator\")\n", + " gr.Markdown(\"Paste your function below to generate helpful unit tests!\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " code_input_2 = gr.Code(label=\"Your Python function here\", lines=20, language=\"python\")\n", + " model_dropdown_2 = gr.Dropdown(choices=models, value=models[0], label=\"Select model\")\n", + " submit_test_btn = gr.Button(\"Generate unit tests\")\n", + " with gr.Column():\n", + " code_output_2 = gr.Code(label=\"Generated unit tests\", language=\"python\")\n", + "\n", + " submit_test_btn.click(\n", + " generate_tests, \n", + " inputs=[code_input_2, model_dropdown_2], \n", + " outputs=code_output_2\n", + " )\n", + " \n", + " \n", + " tab1.select(lambda x: x, 
inputs=code_input_2, outputs=code_input)\n", + " tab2.select(lambda x: x, inputs=code_input, outputs=code_input_2)\n", + "\n", + "ui.launch(share=False, inbrowser=True)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 } diff --git a/week5/community-contributions/Cosmus_Week5_Exercise.ipynb b/week5/community-contributions/Cosmus_Week5_Exercise.ipynb new file mode 100644 index 0000000..ef3da6f --- /dev/null +++ b/week5/community-contributions/Cosmus_Week5_Exercise.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "d04a7c55", + "metadata": {}, + "outputs": [], + "source": [ + "# Importing necessary libraries\n", + "import os\n", + "import sys\n", + "\n", + "# Install faker with this kernel's interpreter, before it is imported below\n", + "!{sys.executable} -m pip install faker\n", + "\n", + "from dotenv import load_dotenv\n", + "from anthropic import Client\n", + "from faker import Faker\n", + "import random\n", + "import gradio as gr\n", + "from langchain_community.document_loaders import DirectoryLoader, TextLoader\n", + "from langchain_text_splitters import CharacterTextSplitter\n", + "from langchain_community.embeddings import HuggingFaceEmbeddings\n", + "from langchain_community.vectorstores import Chroma\n", + "from langchain_anthropic import ChatAnthropic\n", + "from langchain_classic.memory import ConversationBufferMemory\n", + "from langchain_classic.chains import ConversationalRetrievalChain\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3d7f8354", "metadata": {}, "outputs": [], "source": [ "\n", + "# Loading the .env variables\n", + "load_dotenv(override=True)\n", + "\n", + "# Force export to OS env so LangChain can detect it (I had to try this because the key was not loading at one point, but by the time I shared the code it loaded fine, so I commented it out)\n", + "#os.environ[\"ANTHROPIC_API_KEY\"] = os.getenv(\"ANTHROPIC_API_KEY\")\n", + "\n", + "# Getting the key from our .env file. 
It is ANTHROPIC_API_KEY\n", + "ANTHROPIC_KEY = os.getenv(\"ANTHROPIC_API_KEY\")\n", + "client = Client(api_key=ANTHROPIC_KEY)\n", + "\n", + "# Checking which Anthropic models our key can play with\n", + "models = client.models.list()\n", + "for model in models:\n", + " print(model.id)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "20d11d1c", "metadata": {}, "outputs": [], "source": [ "# Getting the Python executable path of this notebook's kernel, to know where to install the faker library\n", + "print(sys.executable)" ] }, { "cell_type": "code", "execution_count": null, "id": "93a8f3ec", "metadata": {}, "outputs": [], "source": [ "# Creating a fake person with faker\n", + "fake = Faker()\n", + "base_dir = \"knowledge_base\"\n", + "folders = [\"personal\", \"projects\", \"learning\"]\n", + "\n", + "# We now create folders if they don't exist\n", + "for folder in folders:\n", + " os.makedirs(f\"{base_dir}/{folder}\", exist_ok=True)\n", + "\n", + "# Check if data already exists\n", + "personal_file = f\"{base_dir}/personal/info.md\"\n", + "projects_file = f\"{base_dir}/projects/projects.md\"\n", + "learning_file = f\"{base_dir}/learning/learning.md\"\n", + "\n", + "# If the personal info file does not exist, create it\n", + "if not os.path.exists(personal_file):\n", + " name = fake.name()\n", + " profession = random.choice([\"Data Analyst\", \"Business Analyst\", \"Software Engineer\", \"AI Specialist\"])\n", + " bio = fake.paragraph(nb_sentences=5)\n", + " experience = \"\\n\".join([f\"- {fake.job()} at {fake.company()} ({fake.year()})\" for _ in range(3)])\n", + " \n", + " personal_text = f\"\"\"\n", + "# Personal Profile\n", + "Name: {name} \n", + "Profession: {profession} \n", + "\n", + "Bio: {bio}\n", + "\n", + "## Experience\n", + "{experience}\n", + "\"\"\"\n", + " with open(personal_file, \"w\") as f:\n", + " f.write(personal_text)\n", + " print(\"Personal info generated.\")\n", + "else:\n", + " # If the personal info file exists, skip the regeneration\n", + " print(\"ℹ️ Personal info already exists. Skipping regeneration.\")\n", + "\n", + "# Doing the same for the projects file\n", + "if not os.path.exists(projects_file):\n", + " projects = \"\\n\".join([\n", + " f\"- **{fake.catch_phrase()}** — {fake.bs().capitalize()} for {fake.company()}.\"\n", + " for _ in range(5)\n", + " ])\n", + " projects_text = f\"\"\"\n", + "# Projects Portfolio\n", + "\n", + "Key Projects:\n", + "{projects}\n", + "\"\"\"\n", + " with open(projects_file, \"w\") as f:\n", + " f.write(projects_text)\n", + " print(\"Projects generated.\")\n", + "else:\n", + " print(\"ℹ️ Projects already exist. Skipping regeneration.\")\n", + "\n", + "# Same thing for the learning file\n", + "if not os.path.exists(learning_file):\n", + " topics = [\"LangChain\", \"RAG Systems\", \"Vector Databases\", \"AI Ethics\", \"Prompt Engineering\", \"Data Visualization\"]\n", + " learning = \"\\n\".join([\n", + " f\"- {random.choice(topics)} — {fake.sentence(nb_words=8)}\"\n", + " for _ in range(6)\n", + " ])\n", + " learning_text = f\"\"\"\n", + "# Learning Journey\n", + "\n", + "Recent Topics and Notes:\n", + "{learning}\n", + "\"\"\"\n", + " with open(learning_file, \"w\") as f:\n", + " f.write(learning_text)\n", + " print(\"Learning notes generated.\")\n", + "else:\n", + " print(\"ℹ️ Learning notes already exist. 
Skipping regeneration.\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6fa19091", "metadata": {}, "outputs": [], "source": [ "# Loading the knowledge documents from the knowledge_base folder\n", + "loader = DirectoryLoader(\"knowledge_base\", glob=\"**/*.md\", loader_cls=TextLoader)\n", + "documents = loader.load()\n", + "\n", + "# Splitting the documents into chunks\n", + "splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=80)\n", + "chunks = splitter.split_documents(documents)\n", + "\n", + "print(f\"Loaded {len(documents)} documents and created {len(chunks)} chunks.\")\n" ] }, { "cell_type": "markdown", "id": "7b9fc9a5", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "id": "6dcdec41", "metadata": {}, "outputs": [], "source": [ "# Creating the embeddings\n", + "embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n", + "\n", + "# Chroma as the vector store\n", + "vectorstore = Chroma.from_documents(chunks, embeddings, persist_directory=\"chroma_db\")\n", + "vectorstore.persist()\n", + "\n", + "print(\"Vector store created and saved to 'chroma_db'.\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "99e4a99f", "metadata": {}, "outputs": [], "source": [ "# Check the LangChain version, since recent releases changed the API and made it tricky to use successfully\n", + "import langchain\n", + "\n", + "print(langchain.__version__)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "5dc1b6ce", "metadata": {}, "outputs": [], "source": [ "# The main LangChain abstractions are: Memory, LLM, and Retriever\n", + "\n", + "# Memory for conversation history\n", + "memory = ConversationBufferMemory(\n", + " memory_key=\"chat_history\",\n", + " return_messages=True\n", + ")\n", + "\n", + "# Using one of the Anthropic models from the list above to create the LLM\n", + "llm = ChatAnthropic(\n", + " model=\"claude-sonnet-4-5-20250929\",\n", + " temperature=0.6,\n", + " max_tokens=1024,\n", + " anthropic_api_key=ANTHROPIC_KEY\n", + ")\n", + "\n", + "# Retriever from your vectorstore\n", + "retriever = vectorstore.as_retriever(search_kwargs={\"k\": 3})\n", + "\n", + "# Bringing everything together into the Conversational RAG Chain\n", + "conversation_chain = ConversationalRetrievalChain.from_llm(\n", + " llm=llm,\n", + " retriever=retriever,\n", + " memory=memory\n", + ")\n", + "\n", + "print(\"Anthropic conversational retriever is ready!\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6f93eea7", "metadata": {}, "outputs": [], "source": [ "# Function that powers the chat interface\n", + "def chat(message, history):\n", + " if conversation_chain:\n", + " result = conversation_chain.invoke({\"question\": message})\n", + " return result[\"answer\"]\n", + " else:\n", + " # Retrieval-only fallback\n", + " docs = retriever.get_relevant_documents(message)\n", + " context = \"\\n\\n\".join([d.page_content for d in docs])\n", + " return f\"(Offline Mode)\\nTop relevant info:\\n\\n{context[:1000]}\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "aadf91b4", "metadata": {}, "outputs": [], "source": [ "# Used some CSS to make the chat interface look better, and dark mode. 
I love dark mode btw\n", + "css = \"\"\"\n", + "body {background-color: #0f1117; color: #e6e6e6;}\n", + ".gradio-container {background-color: #0f1117 !important;}\n", + "textarea, input, .wrap.svelte-1ipelgc {background-color: #1b1f2a !important; color: #ffffff !important;}\n", + "\"\"\"\n", + "\n", + "#Gradio blocks\n", + "with gr.Blocks(css=css, theme=\"gradio/monochrome\") as demo:\n", + " gr.Markdown(\n", + " \"\"\"\n", + "

<h1 align=\"center\">Personal Knowledge Worker</h1>\n", + "<p align=\"center\">Chat with your auto-generated knowledge base (Claude-powered if available)</p>
\n", + " \"\"\",\n", + " elem_id=\"title\"\n", + " )\n", + " gr.ChatInterface(chat, type=\"messages\")\n", + "\n", + "demo.launch(inbrowser=True)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/week5/community-contributions/dkisselev-zz/Week5_Excerise_EmailTerminator.ipynb b/week5/community-contributions/dkisselev-zz/Week5_Excerise_EmailTerminator.ipynb new file mode 100644 index 0000000..fded773 --- /dev/null +++ b/week5/community-contributions/dkisselev-zz/Week5_Excerise_EmailTerminator.ipynb @@ -0,0 +1,1911 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Gmail Terminator\n", + "\n", + "## An Intelligent Email Management System\n", + "\n", + "This application uses RAG (Retrieval Augmented Generation) and LLMs to analyze your Gmail inbox, identify important topics and interests, and help you safely delete unimportant emails with archiving.\n", + "\n", + "### Features:\n", + "- **IMAP Authentication**: Secure app-specific password authentication\n", + "- **Vector Embeddings**: OpenAI or BERT/HuggingFace models\n", + "- **Topic Analysis**: LLM-powered identification of your interests\n", + "- **Category Counts**: See breakdown of email categories\n", + "- **Chat-Based Topics Updates**: Use chat to find specific topics of interest\n", + "- **Selective Deletion**: Choose specific emails to delete with checkboxes\n", + "- **Safe Deletion**: Automatic archiving before deletion\n", + "- **Testing Mode**: Process limited emails with debug output\n", + "\n", + "### Architecture:\n", + "1. Connect to Gmail via IMAP\n", + "2. Fetch and parse emails\n", + "3. Chunk text and create embeddings\n", + "4. Store vectors in ChromaDB\n", + "5. Use LLM to identify important topics\n", + "6. Classify emails as keep/delete\n", + "7. Select specific emails to delete\n", + "8. Archive and safely delete selected emails\n", + "\n", + "## Setup Instructions\n", + "\n", + "### IMAP with App-Specific Password\n", + "\n", + "1. **Enable 2-Factor Authentication** on your Google account (required for app passwords)\n", + "2. **Create App-Specific Password**\n", + " - Go to [Google Account Security](https://myaccount.google.com/security)\n", + " - Under \"2-Step Verification\", find \"App passwords\"\n", + " - Generate a new app password for \"Mail\"\n", + "3. **Store Credentials**\n", + " - **Google Colab**: Store as secrets named `EMAIL` and `IMAP_PASSWORD`\n", + " - **Local**: Add to `.env` file:\n", + " ```\n", + " EMAIL=your.email@gmail.com\n", + " IMAP_PASSWORD=your_16_char_app_password\n", + " ```\n", + "4. 
**Connect**: If credentials are stored, they will auto-populate in the UI" + ], + "metadata": { + "id": "ANmiUlCxG4Bh" + }, + "id": "ANmiUlCxG4Bh" + }, + { + "cell_type": "markdown", + "source": [ + "## Install and Setup" + ], + "metadata": { + "id": "NzQyA5qmu5fv" + }, + "id": "NzQyA5qmu5fv" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f9842a8", + "metadata": { + "id": "6f9842a8" + }, + "outputs": [], + "source": [ + "%pip install -U -q imapclient langchain langchain-openai langchain-chroma langchain-community langchain-core langchain-text-splitters langchain-huggingface chromadb sentence-transformers\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "737e1c9e", + "metadata": { + "id": "737e1c9e" + }, + "outputs": [], + "source": [ + "# Standard library imports\n", + "import os\n", + "import json\n", + "import base64\n", + "import zipfile\n", + "import shutil\n", + "from datetime import datetime\n", + "from collections import Counter\n", + "from typing import List, Dict, Optional, Tuple\n", + "from abc import ABC, abstractmethod\n", + "\n", + "# Third-party imports\n", + "import pandas as pd\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "from bs4 import BeautifulSoup\n", + "\n", + "# IMAP imports\n", + "import imaplib\n", + "import email\n", + "from email.header import decode_header\n", + "\n", + "# LangChain v1.0+ imports\n", + "from langchain_core.documents import Document\n", + "from langchain_core.messages import HumanMessage\n", + "from langchain_text_splitters import CharacterTextSplitter\n", + "from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n", + "from langchain_chroma import Chroma\n", + "from langchain_huggingface import HuggingFaceEmbeddings\n", + "from langchain_core.callbacks import StdOutCallbackHandler\n", + "\n", + "# LLM APIs\n", + "from openai import OpenAI\n", + "\n", + "# HuggingFace\n", + "from huggingface_hub import login\n", + "\n", + "# Gradio\n", + "import gradio as gr\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "191dc787", + "metadata": { + "id": "191dc787" + }, + "outputs": [], + "source": [ + "def setup_api_keys():\n", + " try:\n", + " # Try Colab environment first\n", + " from google.colab import userdata\n", + " api_keys = {\n", + " 'openai': userdata.get('OPENAI_API_KEY'),\n", + " 'anthropic': userdata.get('ANTHROPIC_API_KEY'),\n", + " 'google': userdata.get('GOOGLE_API_KEY'),\n", + " 'hf_token': userdata.get('HF_TOKEN')\n", + " }\n", + " email = userdata.get('EMAIL')\n", + " password = userdata.get('IMAP_PASSWORD')\n", + " print(\"✅ Using Colab secrets\")\n", + " except:\n", + " # Fallback to local environment\n", + " from dotenv import load_dotenv\n", + " load_dotenv()\n", + " api_keys = {\n", + " 'openai': os.getenv('OPENAI_API_KEY'),\n", + " 'anthropic': os.getenv('ANTHROPIC_API_KEY'),\n", + " 'google': os.getenv('GOOGLE_API_KEY'),\n", + " 'hf_token': os.getenv('HF_TOKEN')\n", + " }\n", + "\n", + " email = os.getenv('EMAIL', '')\n", + " password = os.getenv('IMAP_PASSWORD', '')\n", + " print(\"✅ Using local .env file\")\n", + "\n", + " # Initialize API clients\n", + " anthropic_url = \"https://api.anthropic.com/v1/\"\n", + " gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "\n", + " clients = {}\n", + " if api_keys['openai']:\n", + " clients['openai'] = OpenAI(api_key=api_keys['openai'])\n", + " if api_keys['anthropic']:\n", + " clients['anthropic'] = OpenAI(api_key=api_keys['anthropic'], base_url=anthropic_url)\n", + " if 
api_keys['google']:\n", + " clients['google'] = OpenAI(api_key=api_keys['google'], base_url=gemini_url)\n", + " if api_keys['hf_token']:\n", + " login(api_keys['hf_token'])\n", + "\n", + " os.environ['OPENAI_API_KEY'] = api_keys['openai']\n", + " os.environ['ANTHROPIC_API_KEY'] = api_keys['anthropic']\n", + " os.environ['GOOGLE_API_KEY'] = api_keys['google']\n", + "\n", + " return api_keys, clients, email, password\n", + "\n", + "# Initialize API keys and clients\n", + "api_keys, clients, default_email, default_password = setup_api_keys()\n", + "\n", + "# Constants\n", + "MODEL_OPENAI = \"gpt-4o-mini\"\n", + "MODEL_GEMINI = \"gemini-2.5-pro\"\n", + "DB_NAME = \"email_vector_db\"\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "##Helper Functions" + ], + "metadata": { + "id": "hUiNY8_I8ac0" + }, + "id": "hUiNY8_I8ac0" + }, + { + "cell_type": "code", + "source": [ + "def get_header_value(headers, name):\n", + " \"\"\"Get header value from email headers.\"\"\"\n", + " for header in headers:\n", + " if header['name'].lower() == name.lower():\n", + " return header['value']\n", + " return \"\"" + ], + "metadata": { + "id": "Y4MjoYtb8b4i" + }, + "id": "Y4MjoYtb8b4i", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Gmail Connection Classes" + ], + "metadata": { + "id": "g7F4Xgw98jec" + }, + "id": "g7F4Xgw98jec" + }, + { + "cell_type": "code", + "source": [ + "class GmailConnection(ABC):\n", + " \"\"\"Abstract base class for Gmail connections.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.connection = None\n", + " self.auth_info = None\n", + "\n", + " @abstractmethod\n", + " def connect(self) -> bool:\n", + " pass\n", + "\n", + " def fetch_emails(self, max_emails: Optional[int] = None) -> Tuple[List[Document], str]:\n", + " \"\"\"Fetch emails. Returns (documents, diagnostic_message).\"\"\"\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def delete_emails(self, documents: List[Document]) -> Tuple[int, int]:\n", + " pass\n", + "\n", + " def get_auth_info(self) -> Dict:\n", + " return self.auth_info\n", + "\n", + " def is_connected(self) -> bool:\n", + " return self.connection is not None\n", + "\n", + "\n", + "class IMAPConnection(GmailConnection):\n", + " \"\"\"IMAP Gmail connection.\n", + "\n", + " IMPORTANT: For proper email deletion with Gmail IMAP, configure these settings:\n", + " 1. Go to Gmail Settings → Forwarding and POP/IMAP tab\n", + " 2. Under \"When I mark a message in IMAP as deleted\":\n", + " - Set to \"Auto-Expunge off - Wait for the client to update the server\"\n", + " 3. Under \"When a message is marked as deleted and expunged from the last visible IMAP folder\":\n", + " - Select \"Move the message to the Trash\"\n", + " 4. 
Make sure \"Trash\" label is set to \"Show in IMAP\" under Labels settings\n", + "\n", + " This ensures deleted emails are properly moved to Trash when expunged.\n", + " \"\"\"\n", + "\n", + " def __init__(self, email_address: str, app_password: str):\n", + " super().__init__()\n", + " self.email_address = email_address\n", + " self.app_password = app_password\n", + "\n", + " def connect(self) -> bool:\n", + " \"\"\"Authenticate with Gmail using IMAP.\"\"\"\n", + " try:\n", + " imaplib._MAXLINE = 10000000 # 10MB\n", + "\n", + " self.connection = imaplib.IMAP4_SSL(\"imap.gmail.com\", 993)\n", + " self.connection.login(self.email_address, self.app_password)\n", + "\n", + " status, messages = self.connection.select(\"INBOX\")\n", + " if status == \"OK\":\n", + " self.auth_info = {\n", + " 'email': self.email_address,\n", + " 'total_messages': int(messages[0]),\n", + " 'auth_method': 'IMAP'\n", + " }\n", + "\n", + " print(f\"✓ IMAP connected as: {self.email_address}\")\n", + " print(f\"✓ Total messages in INBOX: {self.auth_info['total_messages']:,}\")\n", + " return True\n", + " else:\n", + " print(f\"❌ Failed to select INBOX: {status}\")\n", + " return False\n", + "\n", + " except Exception as e:\n", + " print(f\"❌ IMAP authentication failed: {e}\")\n", + " print(\"Make sure you're using an app-specific password.\")\n", + " return False\n", + "\n", + " def fetch_emails(self, max_emails: Optional[int] = None) -> Tuple[List[Document], str]:\n", + " \"\"\"Fetch emails using IMAP with UIDs. Returns (documents, diagnostic_message).\"\"\"\n", + " if not self.connection:\n", + " raise RuntimeError(\"Not connected. Call connect() first.\")\n", + "\n", + " diagnostics = [] # Capture diagnostic messages\n", + "\n", + " try:\n", + " self.connection.select(\"INBOX\")\n", + "\n", + " status, messages = self.connection.uid('search', None, \"ALL\")\n", + "\n", + " if status != \"OK\":\n", + " msg = f\"❌ Search failed with status: {status}\"\n", + " diagnostics.append(msg)\n", + " return [], \"\\n\".join(diagnostics)\n", + "\n", + " msg_uids = messages[0].split()\n", + " diagnostics.append(f\"✓ Found {len(msg_uids)} message UIDs\")\n", + "\n", + " if not msg_uids:\n", + " diagnostics.append(\"❌ No message UIDs returned from search\")\n", + " return [], \"\\n\".join(diagnostics)\n", + "\n", + " if max_emails:\n", + " msg_uids = msg_uids[-max_emails:] # Get most recent\n", + " diagnostics.append(f\" → Limited to {len(msg_uids)} most recent emails\")\n", + "\n", + " diagnostics.append(f\"Fetching {len(msg_uids)} emails...\")\n", + " documents = []\n", + " errors = []\n", + "\n", + " for uid in tqdm(msg_uids, desc=\"Processing emails\"):\n", + " try:\n", + " # Fetch using UID to get both UID and the email content\n", + " status, msg_data = self.connection.uid('fetch', uid, \"(RFC822)\")\n", + " if status != \"OK\":\n", + " errors.append(f\"Fetch failed for UID {uid}: {status}\")\n", + " continue\n", + "\n", + " # Check if msg_data is valid\n", + " if not msg_data or not msg_data[0] or len(msg_data[0]) < 2:\n", + " errors.append(f\"Invalid msg_data for UID {uid}\")\n", + " continue\n", + "\n", + " email_message = email.message_from_bytes(msg_data[0][1])\n", + "\n", + " # Extract headers\n", + " subject = email_message.get(\"Subject\", \"\")\n", + " if subject:\n", + " decoded = decode_header(subject)[0]\n", + " if isinstance(decoded[0], bytes):\n", + " subject = decoded[0].decode(decoded[1] or 'utf-8', errors='ignore')\n", + " else:\n", + " subject = decoded[0]\n", + "\n", + " sender = email_message.get(\"From\", 
\"\")\n", + " recipient = email_message.get(\"To\", \"\")\n", + " date_str = email_message.get(\"Date\", \"\")\n", + "\n", + " # Extract body\n", + " body = \"\"\n", + " if email_message.is_multipart():\n", + " for part in email_message.walk():\n", + " if part.get_content_type() == \"text/plain\":\n", + " try:\n", + " payload = part.get_payload(decode=True)\n", + " if payload:\n", + " body = payload.decode('utf-8', errors='ignore')\n", + " break\n", + " except Exception as e:\n", + " continue\n", + " elif part.get_content_type() == \"text/html\" and not body:\n", + " try:\n", + " payload = part.get_payload(decode=True)\n", + " if payload:\n", + " html = payload.decode('utf-8', errors='ignore')\n", + " body = BeautifulSoup(html, 'html.parser').get_text()\n", + " except Exception as e:\n", + " continue\n", + " else:\n", + " try:\n", + " payload = email_message.get_payload(decode=True)\n", + " if payload:\n", + " body = payload.decode('utf-8', errors='ignore')\n", + " if email_message.get_content_type() == \"text/html\":\n", + " body = BeautifulSoup(body, 'html.parser').get_text()\n", + " else:\n", + " # Try without decoding for plain text\n", + " body = str(email_message.get_payload())\n", + " except Exception as e:\n", + " # Last resort: use subject as body\n", + " body = \"\"\n", + "\n", + " # Clean whitespace\n", + " if body:\n", + " body = ' '.join(body.split())\n", + "\n", + " # Use subject if body is empty or too short\n", + " if not body or len(body) < 10:\n", + " body = subject or \"No content\"\n", + "\n", + " content = f\"Subject: {subject}\\nFrom: {sender}\\nTo: {recipient}\\nDate: {date_str}\\n\\n{body}\"\n", + "\n", + " doc = Document(\n", + " page_content=content,\n", + " metadata={\n", + " 'uid': uid.decode(),\n", + " 'message_id': uid.decode(),\n", + " 'subject': subject,\n", + " 'sender': sender,\n", + " 'recipient': recipient,\n", + " 'date': date_str,\n", + " 'source': 'gmail_imap'\n", + " }\n", + " )\n", + " documents.append(doc)\n", + "\n", + " except Exception as e:\n", + " errors.append(f\"Error processing UID {uid}: {str(e)}\")\n", + " continue\n", + "\n", + " diagnostics.append(f\"✓ Successfully fetched {len(documents)} emails out of {len(msg_uids)} attempted\")\n", + "\n", + " if errors:\n", + " diagnostics.append(f\"\\n⚠️ Encountered {len(errors)} errors:\")\n", + " # Show first 5 errors\n", + " for err in errors[:5]:\n", + " diagnostics.append(f\" • {err}\")\n", + " if len(errors) > 5:\n", + " diagnostics.append(f\" ... and {len(errors) - 5} more errors\")\n", + "\n", + " if len(documents) == 0 and len(msg_uids) > 0:\n", + " diagnostics.append(\"\\n⚠️ WARNING: No documents created despite having UIDs\")\n", + "\n", + " return documents, \"\\n\".join(diagnostics)\n", + "\n", + " except Exception as error:\n", + " diagnostics.append(f\"❌ Fetch error: {error}\")\n", + " import traceback\n", + " diagnostics.append(f\"\\nTraceback:\\n{traceback.format_exc()}\")\n", + " return [], \"\\n\".join(diagnostics)\n", + "\n", + " def delete_emails(self, documents: List[Document]) -> Tuple[int, int]:\n", + " \"\"\"Delete emails using IMAP with proper UID handling for Gmail.\n", + "\n", + " This method works with Gmail's \"Auto-Expunge off\" setting by:\n", + " 1. Using UIDs instead of sequence numbers for reliable identification\n", + " 2. Marking emails with \\\\Deleted flag\n", + " 3. Explicitly calling EXPUNGE to permanently remove them\n", + " 4. 
Moving emails to [Gmail]/Trash (Gmail's default behavior)\n", + " \"\"\"\n", + " if not self.connection:\n", + " raise RuntimeError(\"Not connected. Call connect() first.\")\n", + "\n", + " if not documents:\n", + " return 0, 0\n", + "\n", + " successful, failed = 0, 0\n", + " print(f\"Deleting {len(documents)} emails via IMAP...\")\n", + "\n", + " try:\n", + " # Select INBOX in read-write mode (default)\n", + " status, response = self.connection.select(\"INBOX\")\n", + " if status != \"OK\":\n", + " print(f\"❌ Failed to select INBOX: {response}\")\n", + " return 0, len(documents)\n", + "\n", + " for doc in tqdm(documents, desc=\"Marking emails for deletion\"):\n", + " # Try to get UID first, fall back to message_id\n", + " uid = doc.metadata.get('uid') or doc.metadata.get('message_id')\n", + " if not uid:\n", + " print(f\"⚠️ No UID found for email: {doc.metadata.get('subject', 'Unknown')}\")\n", + " failed += 1\n", + " continue\n", + "\n", + " try:\n", + " # Convert to bytes if it's a string\n", + " if isinstance(uid, str):\n", + " uid = uid.encode()\n", + "\n", + " # Use UID STORE to mark the email as deleted\n", + " # This is more reliable than using sequence numbers\n", + " status, response = self.connection.uid('STORE', uid, '+FLAGS', '(\\\\Deleted)')\n", + "\n", + " if status == \"OK\":\n", + " successful += 1\n", + " else:\n", + " print(f\"⚠️ Failed to mark UID {uid.decode()}: {response}\")\n", + " failed += 1\n", + "\n", + " except Exception as e:\n", + " print(f\"❌ Error deleting UID {uid}: {e}\")\n", + " failed += 1\n", + "\n", + " # Expunge to permanently delete all messages marked as \\\\Deleted\n", + " # With Gmail's \"Auto-Expunge off\", this command is required\n", + " print(f\"\\n📤 Expunging {successful} deleted emails...\")\n", + " try:\n", + " status, response = self.connection.expunge()\n", + " if status == \"OK\":\n", + " print(f\"✓ Expunge successful: {response}\")\n", + " else:\n", + " print(f\"⚠️ Expunge response: {status} - {response}\")\n", + " except Exception as e:\n", + " print(f\"❌ Expunge error: {e}\")\n", + "\n", + " # Close and reselect to ensure changes are committed\n", + " try:\n", + " self.connection.close()\n", + " self.connection.select(\"INBOX\")\n", + " except:\n", + " pass # Not critical if this fails\n", + "\n", + " print(f\"\\n✓ Deletion complete: {successful} successful, {failed} failed\")\n", + " if successful > 0:\n", + " print(f\"ℹ️ With Gmail's settings, deleted emails should appear in [Gmail]/Trash\")\n", + "\n", + " return successful, failed\n", + "\n", + " except Exception as error:\n", + " print(f\"❌ Delete operation error: {error}\")\n", + " return successful, failed\n", + "\n", + "\n", + "def create_gmail_connection(email: str, password: str) -> GmailConnection:\n", + " \"\"\"Factory function to create Gmail connection.\"\"\"\n", + " if not email or not password:\n", + " raise ValueError(\"Email and password required for IMAP\")\n", + " return IMAPConnection(email, password)" + ], + "metadata": { + "id": "Mv4m2UqV8i-b" + }, + "id": "Mv4m2UqV8i-b", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Vector Database Manager" + ], + "metadata": { + "id": "WI1_7UiU8iy3" + }, + "id": "WI1_7UiU8iy3" + }, + { + "cell_type": "code", + "source": [ + "class VectorDatabaseManager:\n", + " \"\"\"Manages vector database operations for email embeddings.\"\"\"\n", + "\n", + " def __init__(self, db_name: str = DB_NAME):\n", + " self.db_name = db_name\n", + " self.vectorstore = None\n", + " self.embeddings = 
None\n", + "\n", + " def create_embeddings(self, model_type: str = \"openai\"):\n", + " \"\"\"Create embedding function based on model type.\"\"\"\n", + " if model_type.lower() == \"openai\":\n", + " print(\"Using OpenAI embeddings...\")\n", + " self.embeddings = OpenAIEmbeddings()\n", + " elif model_type.lower() == \"bert\":\n", + " print(\"Using BERT (HuggingFace) embeddings...\")\n", + " self.embeddings = HuggingFaceEmbeddings(\n", + " model_name=\"sentence-transformers/all-MiniLM-L6-v2\"\n", + " )\n", + " else:\n", + " raise ValueError(f\"Unknown model type: {model_type}. Use 'openai' or 'bert'.\")\n", + "\n", + " return self.embeddings\n", + "\n", + " def create_vector_store(self, chunks: List[Document], recreate: bool = True):\n", + " \"\"\"Chroma vector store from document chunks.\"\"\"\n", + " if not self.embeddings:\n", + " raise RuntimeError(\"Call create_embeddings() first\")\n", + "\n", + " if recreate and os.path.exists(self.db_name):\n", + " print(f\"Deleting existing database: {self.db_name}\")\n", + " try:\n", + " Chroma(persist_directory=self.db_name, embedding_function=self.embeddings).delete_collection()\n", + " except:\n", + " pass\n", + "\n", + " print(f\"Creating vector store with {len(chunks)} chunks\")\n", + " self.vectorstore = Chroma.from_documents(\n", + " documents=chunks,\n", + " embedding=self.embeddings,\n", + " persist_directory=self.db_name\n", + " )\n", + "\n", + " count = self.vectorstore._collection.count()\n", + " print(f\"Vector store created with {count:,} documents\")\n", + "\n", + " return self.vectorstore\n", + "\n", + " def load_vector_store(self):\n", + " \"\"\"Load existing Chroma vector store.\"\"\"\n", + " if not self.embeddings:\n", + " raise RuntimeError(\"Call create_embeddings() first\")\n", + "\n", + " if not os.path.exists(self.db_name):\n", + " raise FileNotFoundError(f\"Vector store not found: {self.db_name}\")\n", + "\n", + " self.vectorstore = Chroma(\n", + " persist_directory=self.db_name,\n", + " embedding_function=self.embeddings\n", + " )\n", + "\n", + " count = self.vectorstore._collection.count()\n", + " print(f\"Loaded vector store with {count:,} documents\")\n", + "\n", + " return self.vectorstore\n", + "\n", + " def get_vectorstore(self):\n", + " \"\"\"Get the vectorstore instance.\"\"\"\n", + " return self.vectorstore" + ], + "metadata": { + "id": "R1S1CEwf9VF7" + }, + "id": "R1S1CEwf9VF7", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Email Processor" + ], + "metadata": { + "id": "LWIukSSu9vl_" + }, + "id": "LWIukSSu9vl_" + }, + { + "cell_type": "code", + "source": [ + "class EmailProcessor:\n", + " \"\"\"Email processor\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.documents = []\n", + " self.chunks = []\n", + " self.llm = None\n", + " self.topics = \"\"\n", + " self.classified_emails = {'keep': [], 'delete': []}\n", + " self.topic_to_emails = {}\n", + " self.email_to_topic = {}\n", + "\n", + " def chunk_documents(self, documents: List[Document], chunk_size: int = 1000, chunk_overlap: int = 200):\n", + " \"\"\"Chunk email documents.\"\"\"\n", + " text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n", + "\n", + " self.documents = documents\n", + " self.chunks = text_splitter.split_documents(documents)\n", + " print(f\"Created {len(self.chunks)} chunks from {len(documents)} documents\")\n", + " return self.chunks\n", + "\n", + " def get_statistics(self, documents: List[Document]) -> Dict:\n", + " \"\"\"Calculate 
statistics.\"\"\"\n", + " if not documents:\n", + " return {}\n", + "\n", + " senders = [doc.metadata.get('sender', '') for doc in documents]\n", + " total_chars = sum(len(doc.page_content) for doc in documents)\n", + "\n", + " return {\n", + " 'total_emails': len(documents),\n", + " 'total_chars': total_chars,\n", + " 'avg_email_length': total_chars // len(documents),\n", + " 'unique_senders': len(set(senders)),\n", + " 'top_senders': Counter(senders).most_common(10)\n", + " }\n", + "\n", + " def create_llm(self, model_type: str = \"openai\", temperature: float = 0.7, debug: bool = False):\n", + " \"\"\"Create LLM instance.\"\"\"\n", + " callbacks = [StdOutCallbackHandler()] if debug else []\n", + "\n", + " if model_type.lower() == \"openai\":\n", + " self.llm = ChatOpenAI(\n", + " temperature=temperature,\n", + " model_name=MODEL_OPENAI,\n", + " callbacks=callbacks\n", + " )\n", + " else:\n", + " self.llm = ChatOpenAI(temperature=temperature, model_name=MODEL_OPENAI)\n", + "\n", + " return self.llm\n", + "\n", + " def analyze_personal_interests(self, documents: List[Document]) -> str:\n", + " \"\"\"Analyze personal interests using LLM.\"\"\"\n", + " if not self.llm:\n", + " raise RuntimeError(\"Call create_llm() first\")\n", + "\n", + " prompt = self._generate_topics_prompt(documents)\n", + " response = self.llm.invoke([HumanMessage(content=prompt)])\n", + " self.topics = response.content\n", + " return self.topics\n", + "\n", + " def _generate_topics_prompt(self, documents: List[Document], user_context: Optional[str] = None) -> str:\n", + " \"\"\"Generate LLM prompt for topic identification.\"\"\"\n", + " senders = [doc.metadata.get('sender', '') for doc in documents]\n", + " subjects = [doc.metadata.get('subject', '') for doc in documents]\n", + " sender_counts = Counter(senders).most_common(20)\n", + "\n", + " context_line = f'Based on the user\\'s query: \"{user_context}\"\\n\\n' if user_context else \"\"\n", + "\n", + " prompt = f\"\"\"\n", + "{context_line}I have {len(documents)} emails. Analyze and identify 5-10 important topics/categories.\n", + "\n", + "Top senders:\n", + "{chr(10).join([f\"- {sender}: {count}\" for sender, count in sender_counts])}\n", + "\n", + "Sample subjects (first 30):\n", + "{chr(10).join([f\"- {subj}\" for subj in subjects[:30]])}\n", + "\n", + "IMPORTANT: Format your response as a simple numbered list with ONLY the topic names, one per line.\n", + "Do NOT use markdown formatting (**, *, etc.).\n", + "Do NOT add descriptions or explanations after the topic name.\n", + "Do NOT add blank lines between topics.\n", + "\n", + "Example format:\n", + "1. Work Projects\n", + "2. Family Communications\n", + "3. 
Professional Development\n", + "\"\"\"\n", + "\n", + " if user_context:\n", + " prompt += f\"\\n\\nYour response should list topics that align with the user's query about: {user_context}\"\n", + "\n", + " return prompt\n", + "\n", + " def extract_topics_from_text(self, topics_text: str) -> List[str]:\n", + " \"\"\"Extract topic list from LLM-generated topics text.\"\"\"\n", + " topics = []\n", + " lines = topics_text.strip().split('\\n')\n", + "\n", + " for line in lines:\n", + " line = line.strip()\n", + "\n", + " # Skip empty lines\n", + " if not line or len(line) < 3:\n", + " continue\n", + "\n", + " # Skip lines that are clearly descriptions (start with lowercase, or too long)\n", + " if line[0].islower() or line.startswith(('Emails', 'Topics', 'Information', 'Communications', 'Offers')):\n", + " continue\n", + "\n", + " # Remove markdown formatting (**, *, _)\n", + " line = line.replace('**', '').replace('*', '').replace('_', '')\n", + "\n", + " # Remove numbering and bullet points\n", + " if line and line[0].isdigit():\n", + " # Remove \"1.\" or \"1)\"\n", + " parts = line.split('.', 1)\n", + " if len(parts) > 1:\n", + " line = parts[1].strip()\n", + " else:\n", + " parts = line.split(')', 1)\n", + " if len(parts) > 1:\n", + " line = parts[1].strip()\n", + " elif line.startswith(('-', '•')):\n", + " line = line[1:].strip()\n", + "\n", + " # Take only the topic name (before any dash or colon describing it)\n", + " if ' - ' in line:\n", + " topic = line.split(' - ')[0].strip()\n", + " elif ':' in line:\n", + " topic = line.split(':')[0].strip()\n", + " else:\n", + " topic = line.strip()\n", + "\n", + " # Validate: reasonable length for a topic name (not a full sentence/description)\n", + " # Topic names should be between 5-60 characters\n", + " if topic and 5 < len(topic) < 60 and not topic.lower().startswith('based on'):\n", + " topics.append(topic)\n", + "\n", + " return topics[:10] # Limit to top 10 topics\n", + "\n", + " def categorize_emails_by_topics(self, documents: List[Document], vectorstore) -> Dict[str, List[Document]]:\n", + " \"\"\"Categorize emails by matching them to identified topics using RAG.\"\"\"\n", + " if not self.topics or not vectorstore:\n", + " return {}\n", + "\n", + " # Extract topic list from the topics text\n", + " topic_list = self.extract_topics_from_text(self.topics)\n", + "\n", + " if not topic_list:\n", + " return {}\n", + "\n", + " # For each topic, find matching emails using vector similarity\n", + " topic_to_emails = {topic: [] for topic in topic_list}\n", + " topic_to_emails['Uncategorized'] = []\n", + "\n", + " # Track which emails have been matched to which topic\n", + " matched_email_ids = set()\n", + " email_to_topic = {} # Map message_id to topic name\n", + "\n", + " retriever = vectorstore.as_retriever(search_kwargs={\"k\": len(documents)})\n", + "\n", + " for topic in topic_list:\n", + " # Query vectorstore for emails matching this topic\n", + " query = f\"Emails about: {topic}\"\n", + " relevant_docs = retriever.invoke(query)\n", + "\n", + " # Take top matches (based on proportion of total emails - ~15% per topic)\n", + " num_matches = max(1, int(len(documents) * 0.15))\n", + "\n", + " for doc in relevant_docs[:num_matches]:\n", + " msg_id = doc.metadata.get('message_id')\n", + " if msg_id and msg_id not in matched_email_ids:\n", + " # Find the original document\n", + " original_doc = next((d for d in documents if d.metadata.get('message_id') == msg_id), None)\n", + " if original_doc:\n", + " topic_to_emails[topic].append(original_doc)\n", 
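+ "                        # Note: matching is first-come-first-served. Topics
+ "                        # are scanned in list order, each claims at most ~15%\n",
+ "                        # of all emails, and recording the id here stops any\n",
+ "                        # later topic from re-claiming the same email.\n",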
+ " matched_email_ids.add(msg_id)\n", + " email_to_topic[msg_id] = topic\n", + "\n", + " # Add uncategorized emails\n", + " for doc in documents:\n", + " msg_id = doc.metadata.get('message_id')\n", + " if msg_id not in matched_email_ids:\n", + " topic_to_emails['Uncategorized'].append(doc)\n", + " email_to_topic[msg_id] = 'Uncategorized'\n", + "\n", + " # Store the mapping for use in dataframe creation\n", + " self.email_to_topic = email_to_topic\n", + "\n", + " return topic_to_emails\n", + "\n", + " def get_topic_counts_display(self, documents: List[Document], vectorstore) -> str:\n", + " \"\"\"Get formatted topic counts for display.\"\"\"\n", + " if not self.topics or not vectorstore:\n", + " return \"No topics identified yet.\"\n", + "\n", + " topic_to_emails = self.categorize_emails_by_topics(documents, vectorstore)\n", + "\n", + " counts_text = \"Email Counts by Identified Topic:\\n\\n\"\n", + "\n", + " # Sort by count, descending\n", + " sorted_topics = sorted(topic_to_emails.items(), key=lambda x: len(x[1]), reverse=True)\n", + "\n", + " for topic, emails in sorted_topics:\n", + " count = len(emails)\n", + " if count > 0:\n", + " counts_text += f\" {topic}: {count} emails\\n\"\n", + "\n", + " total = sum(len(emails) for emails in topic_to_emails.values())\n", + " counts_text += f\"\\n Total: {total} emails\"\n", + "\n", + " return counts_text\n", + "\n", + " def classify_emails(self, documents: List[Document], vectorstore, threshold: float = 0.5):\n", + " \"\"\"Classify emails based on identified topics.\n", + "\n", + " Emails matching identified topics → KEEP\n", + " Emails not matching any topic → DELETE candidates\n", + " \"\"\"\n", + " if not self.topics:\n", + " raise RuntimeError(\"Call analyze_personal_interests() first\")\n", + "\n", + " # Categorize emails by topics\n", + " topic_to_emails = self.categorize_emails_by_topics(documents, vectorstore)\n", + "\n", + " # Emails matching topics are KEPT\n", + " keep_emails = []\n", + " for topic, emails in topic_to_emails.items():\n", + " if topic != 'Uncategorized':\n", + " keep_emails.extend(emails)\n", + "\n", + " # Uncategorized emails are DELETE candidates\n", + " delete_candidates = topic_to_emails.get('Uncategorized', [])\n", + "\n", + " # Store topic categorization for counts display\n", + " self.topic_to_emails = topic_to_emails\n", + "\n", + " self.classified_emails = {'keep': keep_emails, 'delete': delete_candidates}\n", + "\n", + " print(f\"Classification: {len(keep_emails)} keep, {len(delete_candidates)} delete\")\n", + " print(f\"Matched to {len([t for t in topic_to_emails.keys() if t != 'Uncategorized'])} topics\")\n", + " return self.classified_emails\n", + "\n", + " def create_archive(self, documents: List[Document], archive_name: Optional[str] = None) -> str:\n", + " \"\"\"Create ZIP archive of emails.\"\"\"\n", + " if not documents:\n", + " raise ValueError(\"No documents to archive\")\n", + "\n", + " if not archive_name:\n", + " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + " archive_name = f\"email_archive_{timestamp}.zip\"\n", + "\n", + " archive_dir = \"email_archive_temp\"\n", + " os.makedirs(archive_dir, exist_ok=True)\n", + "\n", + " for i, doc in enumerate(documents):\n", + " email_data = {'metadata': doc.metadata, 'content': doc.page_content}\n", + " subject = doc.metadata.get('subject', 'no_subject')[:50]\n", + " safe_subject = \"\".join(c for c in subject if c.isalnum() or c in (' ', '-', '_')).strip()\n", + " filename = f\"{i+1:04d}_{safe_subject}.json\"\n", + "\n", + " with 
open(os.path.join(archive_dir, filename), 'w', encoding='utf-8') as f:\n", + " json.dump(email_data, f, indent=2, ensure_ascii=False)\n", + "\n", + " # Create ZIP\n", + " with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:\n", + " for root, dirs, files in os.walk(archive_dir):\n", + " for file in files:\n", + " zipf.write(os.path.join(root, file), file)\n", + "\n", + " shutil.rmtree(archive_dir)\n", + " print(f\"Archive created: {archive_name}\")\n", + " return archive_name\n", + "\n", + " def emails_to_dataframe(self, documents: List[Document], add_select_column: bool = False) -> pd.DataFrame:\n", + " \"\"\"Convert to DataFrame with Topics column.\"\"\"\n", + " data = [\n", + " {\n", + " 'Topics': self.email_to_topic.get(doc.metadata.get('message_id', ''), 'Unknown'),\n", + " 'Message ID': doc.metadata.get('message_id', ''),\n", + " 'Subject': doc.metadata.get('subject', '')[:100],\n", + " 'Sender': doc.metadata.get('sender', ''),\n", + " 'Length': len(doc.page_content)\n", + " }\n", + " for doc in documents\n", + " ]\n", + " df = pd.DataFrame(data)\n", + "\n", + " if add_select_column:\n", + " # Add Select column as first column\n", + " df.insert(0, 'Select', False)\n", + "\n", + " return df" ], "metadata": { "id": "7fUcjkI79vLa" }, "id": "7fUcjkI79vLa", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Application State" ], "metadata": { "id": "VWqZZRLY94ST" }, "id": "VWqZZRLY94ST" }, { "cell_type": "code", "source": [ "class AppState:\n", + " \"\"\"Global application state.\"\"\"\n", + " def __init__(self):\n", + " self.gmail_conn: Optional[GmailConnection] = None\n", + " self.vector_db_manager = VectorDatabaseManager()\n", + " self.email_processor = EmailProcessor()\n", + " self.testing_mode = False\n", + " self.debug_mode = False\n", + "\n", + "state = AppState()" ], "metadata": { "id": "eHKPF6WB93WZ" }, "id": "eHKPF6WB93WZ", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Gradio Callback Functions" ], "metadata": { "id": "yOCw1doE93LH" }, "id": "yOCw1doE93LH" }, { "cell_type": "code", "source": [ "def connect_imap(email, password):\n", + " try:\n", + " state.gmail_conn = create_gmail_connection(email, password)\n", + " if state.gmail_conn.connect():\n", + " info = state.gmail_conn.get_auth_info()\n", + " return f\"Connected as {info['email']}\\nTotal messages: {info['total_messages']:,}\"\n", + " return \"❌ Authentication failed\"\n", + " except Exception as e:\n", + " return f\"❌ Error: {str(e)}\"\n", + "\n", + "\n", + "def fetch_and_process(testing_mode, embedding_model):\n", + " try:\n", + " if not state.gmail_conn or not state.gmail_conn.is_connected():\n", + " return \"❌ Not authenticated\"\n", + "\n", + " state.testing_mode = testing_mode\n", + " max_emails = 50 if testing_mode else None\n", + "\n", + " documents, fetch_diagnostics = state.gmail_conn.fetch_emails(max_emails)\n", + "\n", + " if not documents:\n", + " return f\"❌ No emails fetched\\n\\n{fetch_diagnostics}\"\n", + 
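"\n",
+ "        # Added note: the steps below run statistics -> chunking -> embeddings\n",
+ "        # -> vector store; create_vector_store() rebuilds the Chroma collection\n",
+ "        # from scratch on each fetch (its recreate flag defaults to True).\n",
+ 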
"\n", + " stats = state.email_processor.get_statistics(documents)\n", + " chunks = state.email_processor.chunk_documents(documents)\n", + "\n", + " state.vector_db_manager.create_embeddings(embedding_model)\n", + " state.vector_db_manager.create_vector_store(chunks)\n", + "\n", + " return f\"\"\"✓ Processing completed!\n", + "\n", + "{fetch_diagnostics}\n", + "\n", + "Total emails: {stats['total_emails']}\n", + "Chunks created: {len(chunks)}\n", + "Top 5 senders:\n", + "{chr(10).join([f\" - {sender}: {count}\" for sender, count in stats['top_senders'][:5]])}\n", + "\"\"\"\n", + " except Exception as e:\n", + " import traceback\n", + " return f\"❌ Error: {str(e)}\\n\\nTraceback:\\n{traceback.format_exc()}\"\n", + "\n", + "\n", + "def analyze_topics(llm_model, threshold):\n", + " try:\n", + " if not state.email_processor.documents:\n", + " return \"❌ No documents loaded\", \"\", None, None\n", + "\n", + " state.email_processor.create_llm(llm_model)\n", + " topics = state.email_processor.analyze_personal_interests(state.email_processor.documents)\n", + "\n", + " # Automatically classify after analysis\n", + " classified = state.email_processor.classify_emails(\n", + " state.email_processor.documents,\n", + " state.vector_db_manager.vectorstore,\n", + " threshold\n", + " )\n", + "\n", + " # Get topic counts after classification (shows which topics emails matched to)\n", + " counts_text = state.email_processor.get_topic_counts_display(\n", + " state.email_processor.documents,\n", + " state.vector_db_manager.vectorstore\n", + " )\n", + "\n", + " # Get the actual topics list that was used for categorization\n", + " topic_list = state.email_processor.extract_topics_from_text(topics)\n", + " formatted_topics = \"Identified Topics:\\n\\n\" + \"\\n\".join([f\"{i+1}. 
{topic}\" for i, topic in enumerate(topic_list)])\n", + "\n", + " keep_df = state.email_processor.emails_to_dataframe(classified['keep'], add_select_column=False)\n", + " delete_df = state.email_processor.emails_to_dataframe(classified['delete'], add_select_column=True)\n", + "\n", + " return formatted_topics, counts_text, keep_df, delete_df\n", + " except Exception as e:\n", + " return f\"❌ Error: {str(e)}\", \"\", None, None\n", + "\n", + "\n", + "def refine_topics_with_chat(chat_query, llm_model, threshold):\n", + " \"\"\"Use LLM to identify topics based on user query about their interests.\"\"\"\n", + " try:\n", + " if not state.email_processor.documents or not state.vector_db_manager.vectorstore:\n", + " return \"❌ Please process emails first\", \"\", None, None\n", + "\n", + " if not chat_query or chat_query.strip() == \"\":\n", + " return \"❌ Please enter a query\", \"\", None, None\n", + "\n", + " # Create LLM if needed\n", + " if not state.email_processor.llm:\n", + " state.email_processor.create_llm(llm_model)\n", + "\n", + " prompt = state.email_processor._generate_topics_prompt(\n", + " state.email_processor.documents,\n", + " user_context=chat_query\n", + " )\n", + "\n", + " response = state.email_processor.llm.invoke([HumanMessage(content=prompt)])\n", + " state.email_processor.topics = response.content\n", + "\n", + " # Automatically classify emails based on the new topics\n", + " classified = state.email_processor.classify_emails(\n", + " state.email_processor.documents,\n", + " state.vector_db_manager.vectorstore,\n", + " threshold\n", + " )\n", + "\n", + " # Get topic counts after classification\n", + " counts_text = state.email_processor.get_topic_counts_display(\n", + " state.email_processor.documents,\n", + " state.vector_db_manager.vectorstore\n", + " )\n", + "\n", + " # Get the actual topics list that was used for categorization\n", + " topic_list = state.email_processor.extract_topics_from_text(state.email_processor.topics)\n", + " formatted_topics = \"Identified Topics:\\n\\n\" + \"\\n\".join([f\"{i+1}. 
{topic}\" for i, topic in enumerate(topic_list)])\n", + "\n", + " keep_df = state.email_processor.emails_to_dataframe(classified['keep'], add_select_column=False)\n", + " delete_df = state.email_processor.emails_to_dataframe(classified['delete'], add_select_column=True)\n", + "\n", + " return formatted_topics, counts_text, keep_df, delete_df\n", + " except Exception as e:\n", + " return f\"❌ Error: {str(e)}\", \"\", None, None\n", + "\n", + "\n", + "def select_all_emails(delete_df):\n", + " \"\"\"Select all delete candidate emails.\"\"\"\n", + " if delete_df is None or len(delete_df) == 0:\n", + " return delete_df\n", + "\n", + " delete_df_copy = delete_df.copy()\n", + " delete_df_copy['Select'] = True\n", + " return delete_df_copy\n", + "\n", + "\n", + "def deselect_all_emails(delete_df):\n", + " \"\"\"Deselect all delete candidate emails.\"\"\"\n", + " if delete_df is None or len(delete_df) == 0:\n", + " return delete_df\n", + "\n", + " delete_df_copy = delete_df.copy()\n", + " delete_df_copy['Select'] = False\n", + " return delete_df_copy\n", + "\n", + "\n", + "def create_archive_file():\n", + " try:\n", + " if not state.email_processor.classified_emails['delete']:\n", + " return \"❌ No emails to archive\", None\n", + "\n", + " archive_path = state.email_processor.create_archive(\n", + " state.email_processor.classified_emails['delete']\n", + " )\n", + " return f\"✓ Archive created: {archive_path}\", archive_path\n", + " except Exception as e:\n", + " return f\"❌ Error: {str(e)}\", None\n", + "\n", + "\n", + "def perform_deletion(confirmation_text, delete_df):\n", + " try:\n", + " if confirmation_text.strip().upper() != \"DELETE\":\n", + " return \"❌ Confirmation failed. Type 'DELETE' to confirm.\"\n", + "\n", + " if delete_df is None or len(delete_df) == 0:\n", + " return \"❌ No emails available for deletion\"\n", + "\n", + " # Get selected emails\n", + " if 'Select' not in delete_df.columns:\n", + " return \"❌ Invalid dataframe format\"\n", + "\n", + " selected_rows = delete_df[delete_df['Select'] == True]\n", + " if len(selected_rows) == 0:\n", + " return \"❌ No emails selected for deletion\"\n", + "\n", + " # Get message IDs of selected emails\n", + " selected_ids = set(selected_rows['Message ID'].tolist())\n", + "\n", + " # Filter documents to only selected ones\n", + " selected_docs = [\n", + " doc for doc in state.email_processor.classified_emails['delete']\n", + " if doc.metadata.get('message_id') in selected_ids\n", + " ]\n", + "\n", + " if not state.gmail_conn:\n", + " return \"❌ Not authenticated\"\n", + "\n", + " success, failed = state.gmail_conn.delete_emails(selected_docs)\n", + "\n", + " return f\"Deletion complete:\\n - Deleted: {success}\\n - Failed: {failed}\\n - Skipped: {len(state.email_processor.classified_emails['delete']) - len(selected_docs)}\"\n", + " except Exception as e:\n", + " return f\"❌ Error: {str(e)}\"" + ], + "metadata": { + "id": "2toGS3_z-dSE" + }, + "id": "2toGS3_z-dSE", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Gradio Interface" + ], + "metadata": { + "id": "ja-oFdo8-h6b" + }, + "id": "ja-oFdo8-h6b" + }, + { + "cell_type": "code", + "source": [ + "with gr.Blocks(title=\"Gmail Inbox Terminator\", theme=gr.themes.Soft()) as app:\n", + " gr.Markdown(\"# 🔥 Gmail Inbox Terminator\")\n", + " gr.Markdown(\"### Intelligent Email Management with AI\")\n", + " gr.Markdown(\"Identify important topics, then delete emails OUTSIDE those topics.\")\n", + "\n", + " with gr.Tabs():\n", + " # Tab 1: 
Connection\n", + " with gr.Tab(\"🔌 Connection\"):\n", + " gr.Markdown(\"## Connect to Gmail via IMAP\")\n", + "\n", + " if default_email and default_password:\n", + " gr.Markdown(\"\"\"\n", + "**✅ Credentials loaded**\n", + "\n", + "Use pre-filled credentials or enter different ones.\n", + "\"\"\")\n", + " else:\n", + " gr.Markdown(\"\"\"\n", + "**Requirements:**\n", + "1. Enable 2-Factor Authentication on your Google account\n", + "2. Create an app-specific password at [Google Account Security](https://myaccount.google.com/security)\n", + "3. Use the app password below (not your regular password)\n", + "\"\"\")\n", + "\n", + " with gr.Row():\n", + " imap_email = gr.Textbox(\n", + " label=\"Email Address\",\n", + " placeholder=\"your.email@gmail.com\",\n", + " value=default_email\n", + " )\n", + " imap_password = gr.Textbox(\n", + " label=\"App Password\",\n", + " type=\"password\",\n", + " placeholder=\"16-character app password\",\n", + " value=default_password\n", + " )\n", + "\n", + " imap_btn = gr.Button(\"Connect\", variant=\"primary\")\n", + " imap_status = gr.Textbox(label=\"Connection Status\", lines=3)\n", + "\n", + " gr.Markdown(\"---\")\n", + " gr.Markdown(\"## Process Emails\")\n", + "\n", + " with gr.Row():\n", + " testing_mode_check = gr.Checkbox(label=\"Testing Mode (50 emails only)\", value=True)\n", + " embedding_dropdown = gr.Dropdown(\n", + " choices=[\"openai\", \"bert\"],\n", + " value=\"openai\",\n", + " label=\"Embedding Model\"\n", + " )\n", + "\n", + " process_btn = gr.Button(\"📥 Fetch and Process Emails\", variant=\"primary\", size=\"lg\")\n", + " process_status = gr.Textbox(label=\"Processing Status\", lines=10)\n", + "\n", + " imap_btn.click(connect_imap, inputs=[imap_email, imap_password], outputs=imap_status)\n", + " process_btn.click(\n", + " fetch_and_process,\n", + " inputs=[testing_mode_check, embedding_dropdown],\n", + " outputs=process_status\n", + " )\n", + "\n", + " # Tab 2: Topic Analysis & Configuration\n", + " with gr.Tab(\"🔍 Topic Analysis & Configuration\"):\n", + " gr.Markdown(\"## a) Configuration\")\n", + "\n", + " with gr.Row():\n", + " llm_dropdown = gr.Dropdown(\n", + " choices=[\"openai\", \"gemini\"],\n", + " value=\"openai\",\n", + " label=\"LLM Model\"\n", + " )\n", + "\n", + " classification_threshold = gr.Slider(\n", + " minimum=0.1,\n", + " maximum=0.9,\n", + " value=0.5,\n", + " step=0.1,\n", + " label=\"Relevance Threshold (higher = more strict, fewer kept)\"\n", + " )\n", + "\n", + " gr.Markdown(\"---\")\n", + " gr.Markdown(\"## b) Interest Analysis\")\n", + " gr.Markdown(\"Identify topics that are IMPORTANT to you. Emails matching these topics will be KEPT, others offered for deletion.\")\n", + "\n", + " analyze_btn = gr.Button(\"🤖 Identify My Interests\", variant=\"primary\", size=\"lg\")\n", + " topics_output = gr.Textbox(label=\"Important Topics\", lines=10)\n", + " counts_output = gr.Textbox(label=\"Category Counts\", lines=8)\n", + "\n", + " gr.Markdown(\"---\")\n", + " gr.Markdown(\"### Refine Topics with LLM Query\")\n", + " gr.Markdown(\"Ask the LLM to identify specific topics based on your interests. 
Results replace topics above.\")\n", + "\n", + " with gr.Row():\n", + " chat_query_input = gr.Textbox(\n", + " label=\"Query about your interests\",\n", + " placeholder=\"e.g., 'What are my most important professional topics?'\",\n", + " scale=3\n", + " )\n", + " chat_submit_btn = gr.Button(\"Submit Query\", variant=\"secondary\", scale=1)\n", + "\n", + " gr.Markdown(\"\"\"\n", + "**Example queries:**\n", + "- \"What are my most important professional topics?\"\n", + "- \"Identify topics related to family and personal life\"\n", + "- \"What work-related topics should I keep?\"\n", + "\"\"\")\n", + "\n", + " # Tab 3: Email Management & Deletion\n", + " with gr.Tab(\"📧 Email Management & Deletion\"):\n", + " gr.Markdown(\"## Classified Emails (based on topic analysis)\")\n", + " gr.Markdown(\"Emails matching your important topics are in 'Keep'. Others are deletion candidates.\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " gr.Markdown(\"### 📌 Keep (Important)\")\n", + " keep_df = gr.Dataframe(label=\"Emails to Keep\", interactive=False)\n", + "\n", + " with gr.Column():\n", + " gr.Markdown(\"### 🗑️ Delete Candidates\")\n", + "\n", + " with gr.Row():\n", + " select_all_btn = gr.Button(\"✅ Select All\", size=\"sm\")\n", + " deselect_all_btn = gr.Button(\"❌ Deselect All\", size=\"sm\")\n", + "\n", + " delete_df = gr.Dataframe(\n", + " label=\"Select emails to delete\",\n", + " interactive=True,\n", + " datatype=[\"bool\", \"str\", \"str\", \"str\", \"str\", \"number\"],\n", + " col_count=(6, \"fixed\")\n", + " )\n", + "\n", + " select_all_btn.click(select_all_emails, inputs=delete_df, outputs=delete_df)\n", + " deselect_all_btn.click(deselect_all_emails, inputs=delete_df, outputs=delete_df)\n", + "\n", + " gr.Markdown(\"---\")\n", + " gr.Markdown(\"## Archive & Delete\")\n", + "\n", + " with gr.Row():\n", + " archive_btn = gr.Button(\"📦 Create Archive\", variant=\"secondary\")\n", + " delete_btn = gr.Button(\"🔥 DELETE SELECTED\", variant=\"stop\")\n", + "\n", + " with gr.Row():\n", + " with gr.Column():\n", + " archive_status = gr.Textbox(label=\"Archive Status\", lines=2)\n", + " with gr.Column():\n", + " confirmation_input = gr.Textbox(label=\"Type DELETE to confirm\", placeholder=\"DELETE\")\n", + "\n", + " archive_file = gr.File(label=\"Download Archive\")\n", + " deletion_status = gr.Textbox(label=\"Deletion Result\", lines=3)\n", + "\n", + " analyze_btn.click(\n", + " analyze_topics,\n", + " inputs=[llm_dropdown, classification_threshold],\n", + " outputs=[topics_output, counts_output, keep_df, delete_df]\n", + " )\n", + "\n", + " chat_submit_btn.click(\n", + " refine_topics_with_chat,\n", + " inputs=[chat_query_input, llm_dropdown, classification_threshold],\n", + " outputs=[topics_output, counts_output, keep_df, delete_df]\n", + " )\n", + "\n", + " archive_btn.click(create_archive_file, outputs=[archive_status, archive_file])\n", + " delete_btn.click(perform_deletion, inputs=[confirmation_input, delete_df], outputs=deletion_status)" + ], + "metadata": { + "id": "iKC3MtzX-jVT" + }, + "id": "iKC3MtzX-jVT", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Launch App" + ], + "metadata": { + "id": "rY9Pbte__Kqa" + }, + "id": "rY9Pbte__Kqa" + }, + { + "cell_type": "code", + "source": [ + "app.launch(share=True, inbrowser=True)" + ], + "metadata": { + "id": "YUHF1ZIl_Nv-" + }, + "id": "YUHF1ZIl_Nv-", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Unit Tests for Components" + 
], + "metadata": { + "id": "jHgVYNTc-tCf" + }, + "id": "jHgVYNTc-tCf" + }, + { + "cell_type": "code", + "source": [ + "\n", + "print(\"=\" * 60)\n", + "print(\"UNIT TESTS - Testing Individual Components\")\n", + "print(\"=\" * 60)\n", + "\n", + "# Test 1: Helper Functions\n", + "print(\"\\n📝 Test 1: Helper Functions\")\n", + "print(\"-\" * 40)\n", + "\n", + "def test_helper_functions():\n", + " \"\"\"Test email parsing helper functions.\"\"\"\n", + " # Test get_header_value\n", + " test_headers = [\n", + " {'name': 'Subject', 'value': 'Test Email'},\n", + " {'name': 'From', 'value': 'sender@example.com'},\n", + " {'name': 'Date', 'value': '2025-10-21'}\n", + " ]\n", + "\n", + " assert get_header_value(test_headers, 'Subject') == 'Test Email'\n", + " assert get_header_value(test_headers, 'From') == 'sender@example.com'\n", + " assert get_header_value(test_headers, 'Missing') == ''\n", + "\n", + " print(\"✓ get_header_value() works correctly\")\n", + " return True\n", + "\n", + "try:\n", + " test_helper_functions()\n", + " print(\"\\n✅ Helper functions test PASSED\")\n", + "except AssertionError as e:\n", + " print(f\"\\n❌ Helper functions test FAILED: {e}\")\n", + "\n", + "# Test 2: VectorDatabaseManager\n", + "print(\"\\n\\n💾 Test 2: VectorDatabaseManager\")\n", + "print(\"-\" * 40)\n", + "\n", + "def test_vector_database_manager():\n", + " \"\"\"Test VectorDatabaseManager class.\"\"\"\n", + " test_docs = [\n", + " Document(\n", + " page_content=\"This is a test email about Python programming and data science.\",\n", + " metadata={'subject': 'Test 1', 'sender': 'test@example.com'}\n", + " ),\n", + " Document(\n", + " page_content=\"Another email discussing machine learning and AI topics.\",\n", + " metadata={'subject': 'Test 2', 'sender': 'ai@example.com'}\n", + " ),\n", + " Document(\n", + " page_content=\"Meeting invitation for tomorrow's project review.\",\n", + " metadata={'subject': 'Test 3', 'sender': 'manager@example.com'}\n", + " )\n", + " ]\n", + "\n", + " test_mgr = VectorDatabaseManager(db_name=\"test_vector_db\")\n", + " embeddings = test_mgr.create_embeddings(\"bert\")\n", + " assert test_mgr.embeddings is not None\n", + " print(\"✓ Embeddings created successfully\")\n", + "\n", + " vectorstore = test_mgr.create_vector_store(test_docs, recreate=True)\n", + " assert vectorstore is not None\n", + " assert test_mgr.vectorstore._collection.count() == len(test_docs)\n", + " print(f\"✓ Vector store created with {len(test_docs)} documents\")\n", + "\n", + " retriever = vectorstore.as_retriever(search_kwargs={\"k\": 2})\n", + " results = retriever.invoke(\"Python programming\")\n", + " assert len(results) > 0\n", + " print(f\"✓ Retrieval works: found {len(results)} relevant documents\")\n", + "\n", + " if os.path.exists(\"test_vector_db\"):\n", + " shutil.rmtree(\"test_vector_db\")\n", + "\n", + " return True\n", + "\n", + "try:\n", + " test_vector_database_manager()\n", + " print(\"\\n✅ VectorDatabaseManager test PASSED\")\n", + "except Exception as e:\n", + " print(f\"\\n❌ VectorDatabaseManager test FAILED: {e}\")\n", + "\n", + "# Test 3: EmailProcessor\n", + "print(\"\\n\\n📧 Test 3: EmailProcessor\")\n", + "print(\"-\" * 40)\n", + "\n", + "def test_email_processor():\n", + " \"\"\"Test EmailProcessor class.\"\"\"\n", + " test_docs = [\n", + " Document(\n", + " page_content=\"Subject: Project Update\\nFrom: boss@company.com\\nTo: me@company.com\\nDate: 2025-10-20\\n\\nPlease review the quarterly report.\",\n", + " metadata={'subject': 'Project Update', 'sender': 'boss@company.com', 
'message_id': '001', 'date': '2025-10-20'}\n", + " ),\n", + " Document(\n", + " page_content=\"Subject: Newsletter\\nFrom: marketing@spam.com\\nTo: me@company.com\\nDate: 2025-10-19\\n\\nCheck out our latest deals!\",\n", + " metadata={'subject': 'Newsletter', 'sender': 'marketing@spam.com', 'message_id': '002', 'date': '2025-10-19'}\n", + " ),\n", + " Document(\n", + " page_content=\"Subject: Team Meeting\\nFrom: colleague@company.com\\nTo: me@company.com\\nDate: 2025-10-21\\n\\nMeeting tomorrow at 10am.\",\n", + " metadata={'subject': 'Team Meeting', 'sender': 'colleague@company.com', 'message_id': '003', 'date': '2025-10-21'}\n", + " )\n", + " ]\n", + "\n", + " processor = EmailProcessor()\n", + "\n", + " chunks = processor.chunk_documents(test_docs, chunk_size=100, chunk_overlap=20)\n", + " assert len(chunks) >= len(test_docs)\n", + " print(f\"✓ Chunking works: created {len(chunks)} chunks from {len(test_docs)} documents\")\n", + "\n", + " stats = processor.get_statistics(test_docs)\n", + " assert stats['total_emails'] == 3\n", + " assert stats['unique_senders'] == 3\n", + " print(f\"✓ Statistics calculation works: {stats['total_emails']} emails, {stats['unique_senders']} unique senders\")\n", + "\n", + " df = processor.emails_to_dataframe(test_docs, add_select_column=True)\n", + " assert len(df) == 3\n", + " assert 'Topics' in df.columns\n", + " assert 'Subject' in df.columns\n", + " assert 'Sender' in df.columns\n", + " assert 'Select' in df.columns\n", + " print(f\"✓ DataFrame conversion works: {len(df)} rows, {len(df.columns)} columns\")\n", + "\n", + " return True\n", + "\n", + "try:\n", + " test_email_processor()\n", + " print(\"\\n✅ EmailProcessor test PASSED\")\n", + "except Exception as e:\n", + " print(f\"\\n❌ EmailProcessor test FAILED: {e}\")\n", + "\n", + "# Test 4: Mock IMAP Connection\n", + "print(\"\\n\\n🔌 Test 4: Mock IMAP Connection\")\n", + "print(\"-\" * 40)\n", + "\n", + "def test_mock_connection():\n", + " \"\"\"Test the connection interface with a mock implementation.\"\"\"\n", + "\n", + " class MockIMAPConnection(GmailConnection):\n", + " \"\"\"Mock implementation for testing.\"\"\"\n", + "\n", + " def connect(self) -> bool:\n", + " self.auth_info = {\n", + " 'email': 'test@example.com',\n", + " 'total_messages': 100,\n", + " 'auth_method': 'Mock'\n", + " }\n", + " self.connection = \"mock_connection\"\n", + " return True\n", + "\n", + " def fetch_emails(self, max_emails: Optional[int] = None) -> Tuple[List[Document], str]:\n", + " limit = max_emails if max_emails else 10\n", + " docs = [\n", + " Document(\n", + " page_content=f\"Mock email {i}\",\n", + " metadata={\n", + " 'message_id': f'mock_{i}',\n", + " 'subject': f'Test Subject {i}',\n", + " 'sender': f'sender{i}@example.com',\n", + " 'date': '2025-10-21'\n", + " }\n", + " )\n", + " for i in range(min(limit, 5))\n", + " ]\n", + " return docs, f\"✓ Fetched {len(docs)} mock emails\"\n", + "\n", + " def delete_emails(self, documents: List[Document]) -> Tuple[int, int]:\n", + " return len(documents), 0\n", + "\n", + " mock_conn = MockIMAPConnection()\n", + "\n", + " assert mock_conn.connect()\n", + " print(\"✓ Mock connection established\")\n", + "\n", + " assert mock_conn.is_connected()\n", + " print(\"✓ Connection status check works\")\n", + "\n", + " info = mock_conn.get_auth_info()\n", + " assert info['email'] == 'test@example.com'\n", + " print(f\"✓ Auth info retrieved: {info['email']}\")\n", + "\n", + " emails, diagnostics = mock_conn.fetch_emails(max_emails=3)\n", + " assert len(emails) == 3\n", + " 
print(f\"✓ Fetched {len(emails)} mock emails\")\n", + " print(f\" Diagnostics: {diagnostics}\")\n", + "\n", + " success, failed = mock_conn.delete_emails(emails)\n", + " assert success == 3 and failed == 0\n", + " print(f\"✓ Mock deletion: {success} successful, {failed} failed\")\n", + "\n", + " return True\n", + "\n", + "try:\n", + " test_mock_connection()\n", + " print(\"\\n✅ Mock connection test PASSED\")\n", + "except Exception as e:\n", + " print(f\"\\n❌ Mock connection test FAILED: {e}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"✅ ALL UNIT TESTS COMPLETED\")\n", + "print(\"=\" * 60)\n" + ], + "metadata": { + "id": "NQjxVtZl-sNm" + }, + "id": "NQjxVtZl-sNm", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Integration Test (with Mock Data)" + ], + "metadata": { + "id": "sA6A8f2Q-r_2" + }, + "id": "sA6A8f2Q-r_2" + }, + { + "cell_type": "code", + "source": [ + "print(\"\\n\\n\" + \"=\" * 60)\n", + "print(\"INTEGRATION TEST - Full Workflow with Mock Data\")\n", + "print(\"=\" * 60)\n", + "\n", + "def run_integration_test():\n", + " \"\"\"Run a complete workflow test with mock data.\"\"\"\n", + "\n", + " print(\"\\n🚀 Starting integration test...\")\n", + "\n", + " # Step 1: Create mock connection\n", + " print(\"\\n1️⃣ Creating mock Gmail connection...\")\n", + "\n", + " class TestGmailConnection(GmailConnection):\n", + " def connect(self):\n", + " self.connection = True\n", + " self.auth_info = {'email': 'test@example.com', 'total_messages': 20, 'auth_method': 'Test'}\n", + " return True\n", + "\n", + " def fetch_emails(self, max_emails=None):\n", + " # Generate realistic mock emails\n", + " topics = [\n", + " (\"Work Project\", \"manager@company.com\", \"Need your input on Q4 planning and budget allocation.\"),\n", + " (\"Team Meeting\", \"colleague@company.com\", \"Weekly sync tomorrow at 10am to discuss progress.\"),\n", + " (\"Newsletter\", \"marketing@newsletter.com\", \"Top 10 deals this week! Don't miss out!\"),\n", + " (\"Spam Offer\", \"deals@promo.com\", \"You've won a million dollars! 
Click here now!\"),\n", + " (\"Client Update\", \"client@business.com\", \"Regarding the proposal you sent last week.\"),\n", + " (\"Training Course\", \"learning@company.com\", \"New Python course available for employees.\"),\n", + " (\"Marketing Email\", \"ads@shopping.com\", \"Summer sale - 50% off everything!\"),\n", + " (\"Boss Email\", \"ceo@company.com\", \"Great job on the presentation yesterday!\"),\n", + " (\"Junk\", \"random@spam.com\", \"Make money fast with this one weird trick!\"),\n", + " (\"Important Notice\", \"hr@company.com\", \"Annual review meeting scheduled for next month.\")\n", + " ]\n", + "\n", + " limit = min(max_emails if max_emails else 10, len(topics))\n", + "\n", + " docs = [\n", + " Document(\n", + " page_content=f\"Subject: {subj}\\nFrom: {sender}\\nTo: test@example.com\\nDate: 2025-10-{20-i}\\n\\n{body}\",\n", + " metadata={\n", + " 'message_id': f'test_{i}',\n", + " 'subject': subj,\n", + " 'sender': sender,\n", + " 'recipient': 'test@example.com',\n", + " 'date': f'2025-10-{20-i}',\n", + " 'source': 'test'\n", + " }\n", + " )\n", + " for i, (subj, sender, body) in enumerate(topics[:limit])\n", + " ]\n", + " return docs, f\"✓ Fetched {len(docs)} test emails\"\n", + "\n", + " def delete_emails(self, documents):\n", + " return len(documents), 0\n", + "\n", + " test_conn = TestGmailConnection()\n", + " test_conn.connect()\n", + " print(f\" ✓ Connected as: {test_conn.get_auth_info()['email']}\")\n", + "\n", + " # Step 2: Fetch emails\n", + " print(\"\\n2️⃣ Fetching mock emails...\")\n", + " emails, diagnostics = test_conn.fetch_emails(max_emails=10)\n", + " print(f\" ✓ Fetched {len(emails)} emails\")\n", + " print(f\" {diagnostics}\")\n", + "\n", + " # Step 3: Process emails\n", + " print(\"\\n3️⃣ Processing emails...\")\n", + " processor = EmailProcessor()\n", + " chunks = processor.chunk_documents(emails)\n", + " print(f\" ✓ Created {len(chunks)} chunks\")\n", + "\n", + " stats = processor.get_statistics(emails)\n", + " print(f\" ✓ Statistics: {stats['total_emails']} emails, {stats['unique_senders']} senders\")\n", + "\n", + " # Step 4: Create vector store\n", + " print(\"\\n4️⃣ Creating vector store...\")\n", + " vector_mgr = VectorDatabaseManager(db_name=\"test_integration_db\")\n", + " vector_mgr.create_embeddings(\"bert\") # Use BERT to avoid API costs\n", + " vector_mgr.create_vector_store(chunks, recreate=True)\n", + " print(f\" ✓ Vector store created with {vector_mgr.vectorstore._collection.count()} documents\")\n", + "\n", + " # Step 5: Analyze topics (simulated - would normally use LLM)\n", + " print(\"\\n5️⃣ Analyzing topics...\")\n", + " processor.topics = \"\"\"\n", + "Based on the email analysis:\n", + "1. Work Projects - Manager communications about planning and budgets\n", + "2. Team Collaboration - Meeting invites and team sync-ups\n", + "3. Client Relations - Important client communications\n", + "4. Professional Development - Training and learning opportunities\n", + "5. 
Company Announcements - HR and leadership communications\n", + "\"\"\"\n", + " print(\" Topics identified (mock analysis)\")\n", + "\n", + " # Step 6: Classify emails\n", + " print(\"\\n6️⃣ Classifying emails...\")\n", + " # Simulate classification based on sender domains\n", + " work_domains = ['company.com', 'business.com']\n", + " spam_domains = ['newsletter.com', 'promo.com', 'spam.com', 'shopping.com']\n", + "\n", + " keep_emails = [email for email in emails if any(domain in email.metadata.get('sender', '') for domain in work_domains)]\n", + " delete_emails = [email for email in emails if any(domain in email.metadata.get('sender', '') for domain in spam_domains)]\n", + "\n", + " processor.classified_emails = {'keep': keep_emails, 'delete': delete_emails}\n", + " print(f\" ✓ Classification complete:\")\n", + " print(f\" - Keep: {len(keep_emails)} emails\")\n", + " print(f\" - Delete: {len(delete_emails)} emails\")\n", + "\n", + " # Step 7: Create archive\n", + " print(\"\\n7️⃣ Creating archive...\")\n", + " if delete_emails:\n", + " archive_path = processor.create_archive(delete_emails)\n", + " print(f\" ✓ Archive created: {archive_path}\")\n", + " archive_exists = os.path.exists(archive_path)\n", + " print(f\" ✓ Archive file exists: {archive_exists}\")\n", + "\n", + " # Step 8: Simulate deletion\n", + " print(\"\\n8️⃣ Simulating deletion...\")\n", + " success, failed = test_conn.delete_emails(delete_emails)\n", + " print(f\" ✓ Deletion complete: {success} successful, {failed} failed\")\n", + "\n", + " # Step 9: Display results as DataFrame\n", + " print(\"\\n9️⃣ Generating reports...\")\n", + " keep_df = processor.emails_to_dataframe(keep_emails)\n", + " delete_df = processor.emails_to_dataframe(delete_emails)\n", + " print(f\" ✓ Keep DataFrame: {len(keep_df)} rows\")\n", + " print(f\" ✓ Delete DataFrame: {len(delete_df)} rows\")\n", + "\n", + " # Cleanup\n", + " print(\"\\n🧹 Cleaning up test files...\")\n", + " if os.path.exists(\"test_integration_db\"):\n", + " shutil.rmtree(\"test_integration_db\")\n", + " if delete_emails and os.path.exists(archive_path):\n", + " os.remove(archive_path)\n", + " print(\" ✓ Cleanup complete\")\n", + "\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"✅ INTEGRATION TEST COMPLETED SUCCESSFULLY!\")\n", + " print(\"=\" * 60)\n", + " print(\"\\n📊 Summary:\")\n", + " print(f\" • Total emails processed: {len(emails)}\")\n", + " print(f\" • Emails to keep: {len(keep_emails)}\")\n", + " print(f\" • Emails to delete: {len(delete_emails)}\")\n", + " print(f\" • Archive created: ✓\")\n", + " print(f\" • Deletion simulated: ✓\")\n", + " print(\"\\n💡 The refactored architecture makes testing easy!\")\n", + "\n", + " return True\n", + "\n", + "try:\n", + " run_integration_test()\n", + "except Exception as e:\n", + " print(f\"\\n❌ INTEGRATION TEST FAILED: {e}\")\n", + " import traceback\n", + " traceback.print_exc()" + ], + "metadata": { + "id": "5MBAXKSW-9qp" + }, + "id": "5MBAXKSW-9qp", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##Performance Test" + ], + "metadata": { + "id": "zpaJTrOp_BdP" + }, + "id": "zpaJTrOp_BdP" + }, + { + "cell_type": "code", + "source": [ + "\n", + "print(\"\\n\\n\" + \"=\" * 60)\n", + "print(\"PERFORMANCE TEST - Component Benchmarks\")\n", + "print(\"=\" * 60)\n", + "\n", + "import time\n", + "\n", + "def benchmark_component(name, func, *args, **kwargs):\n", + " \"\"\"Benchmark a component function.\"\"\"\n", + " start = time.time()\n", + " result = func(*args, **kwargs)\n", + " 
elapsed = time.time() - start\n", + " print(f\" {name}: {elapsed:.3f}s\")\n", + " return result, elapsed\n", + "\n", + "def run_performance_tests():\n", + " \"\"\"Run performance benchmarks.\"\"\"\n", + "\n", + " # Generate test data\n", + " print(\"\\n📊 Generating test data...\")\n", + " test_emails = [\n", + " Document(\n", + " page_content=f\"Subject: Test {i}\\nFrom: sender{i % 10}@example.com\\n\\n\" + \" \".join([\"word\"] * 100),\n", + " metadata={\n", + " 'message_id': f'perf_{i}',\n", + " 'subject': f'Test {i}',\n", + " 'sender': f'sender{i % 10}@example.com',\n", + " 'date': f'2025-10-{(i % 30) + 1:02d}'\n", + " }\n", + " )\n", + " for i in range(100)\n", + " ]\n", + " print(f\" ✓ Created {len(test_emails)} test emails\")\n", + "\n", + " # Benchmark EmailProcessor\n", + " print(\"\\n⏱️ Benchmarking EmailProcessor...\")\n", + " processor = EmailProcessor()\n", + "\n", + " chunks, t1 = benchmark_component(\"Chunking\", processor.chunk_documents, test_emails)\n", + " stats, t2 = benchmark_component(\"Statistics\", processor.get_statistics, test_emails)\n", + " df, t3 = benchmark_component(\"DataFrame conversion\", processor.emails_to_dataframe, test_emails)\n", + "\n", + " # Benchmark VectorDatabaseManager\n", + " print(\"\\n⏱️ Benchmarking VectorDatabaseManager...\")\n", + " vector_mgr = VectorDatabaseManager(db_name=\"test_perf_db\")\n", + "\n", + " emb, t4 = benchmark_component(\"Embedding creation\", vector_mgr.create_embeddings, \"bert\")\n", + " vs, t5 = benchmark_component(\"Vector store creation\", vector_mgr.create_vector_store, chunks[:50]) # Limit for speed\n", + "\n", + " # Cleanup\n", + " if os.path.exists(\"test_perf_db\"):\n", + " shutil.rmtree(\"test_perf_db\")\n", + "\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"✅ PERFORMANCE TEST COMPLETED\")\n", + " print(\"=\" * 60)\n", + " print(f\"\\n📈 Total time: {t1 + t2 + t3 + t4 + t5:.3f}s\")\n", + " print(f\" Fastest operation: DataFrame conversion ({t3:.3f}s)\")\n", + " print(f\" Slowest operation: Vector store creation ({t5:.3f}s)\")\n", + "\n", + "try:\n", + " run_performance_tests()\n", + "except Exception as e:\n", + " print(f\"\\n❌ PERFORMANCE TEST FAILED: {e}\")\n", + "\n" + ], + "metadata": { + "id": "41w8FGJ9_CCU" + }, + "id": "41w8FGJ9_CCU", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/week5/community-contributions/w5d5_worker.py b/week5/community-contributions/w5d5_worker.py new file mode 100644 index 0000000..822cfcd --- /dev/null +++ b/week5/community-contributions/w5d5_worker.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 +""" +Knowledge Worker with Document Upload and Google Drive Integration + +This script creates a knowledge worker that: +1. Allows users to upload documents through a Gradio UI +2. Integrates with Google Drive to access documents +3. Uses Chroma vector database for efficient document retrieval +4. Implements RAG (Retrieval Augmented Generation) for accurate responses + +The system updates its context dynamically when new documents are uploaded. 
+""" + +import os +import glob +import tempfile +from pathlib import Path +from dotenv import load_dotenv +import gradio as gr + +# LangChain imports +from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader +from langchain_core.documents import Document +from langchain_openai import OpenAIEmbeddings, ChatOpenAI +from langchain_chroma import Chroma + +# Visualization imports +import numpy as np +from sklearn.manifold import TSNE +import plotly.graph_objects as go + +# Removed Google Drive API imports + +# Additional document loaders +try: + from langchain_community.document_loaders import Docx2txtLoader, UnstructuredExcelLoader +except ImportError: + print("Warning: Some document loaders not available. PDF and text files will still work.") + Docx2txtLoader = None + UnstructuredExcelLoader = None + +# Configuration +MODEL = "gpt-4o-mini" # Using a cost-effective model +DB_NAME = "knowledge_worker_db" +UPLOAD_FOLDER = "uploaded_documents" + +# Create upload folder if it doesn't exist +os.makedirs(UPLOAD_FOLDER, exist_ok=True) + +# Load environment variables +load_dotenv(override=True) +os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env') + +# Removed Google Drive credentials configuration + +# Use a simple text splitter approach +class SimpleTextSplitter: + def __init__(self, chunk_size=1000, chunk_overlap=200): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def split_documents(self, documents): + chunks = [] + for doc in documents: + text = doc.page_content + start = 0 + while start < len(text): + end = start + self.chunk_size + chunk_text = text[start:end] + chunk_doc = Document(page_content=chunk_text, metadata=doc.metadata.copy()) + chunks.append(chunk_doc) + start = end - self.chunk_overlap + return chunks + +CharacterTextSplitter = SimpleTextSplitter + +# Try different import paths for memory and chains +try: + from langchain.memory import ConversationBufferMemory + from langchain.chains import ConversationalRetrievalChain +except ImportError: + try: + from langchain_core.memory import ConversationBufferMemory + from langchain_core.chains import ConversationalRetrievalChain + except ImportError: + try: + from langchain_community.memory import ConversationBufferMemory + from langchain_community.chains import ConversationalRetrievalChain + except ImportError: + print("Warning: Memory and chains modules not found. Creating simple alternatives.") + # Create simple alternatives + class ConversationBufferMemory: + def __init__(self, memory_key='chat_history', return_messages=True): + self.memory_key = memory_key + self.return_messages = return_messages + self.chat_memory = [] + + def save_context(self, inputs, outputs): + self.chat_memory.append((inputs, outputs)) + + def load_memory_variables(self, inputs): + return {self.memory_key: self.chat_memory} + + class ConversationalRetrievalChain: + def __init__(self, llm, retriever, memory): + self.llm = llm + self.retriever = retriever + self.memory = memory + + def invoke(self, inputs): + question = inputs.get("question", "") + # Simple implementation - just return a basic response + return {"answer": f"I received your question: {question}. 
This is a simplified response."} + +# Removed Google Drive Integration Functions + +# Document Processing Functions +def get_loader_for_file(file_path): + """ + Get the appropriate document loader based on file extension + """ + file_extension = os.path.splitext(file_path)[1].lower() + + if file_extension == '.pdf': + return PyPDFLoader(file_path) + elif file_extension in ['.docx', '.doc'] and Docx2txtLoader: + return Docx2txtLoader(file_path) + elif file_extension in ['.xlsx', '.xls'] and UnstructuredExcelLoader: + return UnstructuredExcelLoader(file_path) + elif file_extension in ['.txt', '.md']: + return TextLoader(file_path, encoding='utf-8') + else: + # Default to text loader for unknown types + try: + return TextLoader(file_path, encoding='utf-8') + except: + return None + +def load_document(file_path): + """ + Load a document using the appropriate loader + """ + loader = get_loader_for_file(file_path) + if loader: + try: + return loader.load() + except Exception as e: + print(f"Error loading document {file_path}: {e}") + return [] + +def process_documents(documents): + """ + Split documents into chunks for embedding + """ + text_splitter = CharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200 + ) + chunks = text_splitter.split_documents(documents) + return chunks + +# Knowledge Base Class +class KnowledgeBase: + def __init__(self, db_name=DB_NAME): + self.db_name = db_name + self.embeddings = OpenAIEmbeddings() + self.vectorstore = None + self.initialize_vectorstore() + + def initialize_vectorstore(self): + """ + Initialize the vector store, loading from disk if it exists + """ + if os.path.exists(self.db_name): + self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings) + print(f"Loaded existing vector store with {self.vectorstore._collection.count()} documents") + else: + # Create empty vectorstore + self.vectorstore = Chroma(persist_directory=self.db_name, embedding_function=self.embeddings) + print("Created new vector store") + + def add_documents(self, documents): + """ + Process and add documents to the vector store + """ + if not documents: + return False + + chunks = process_documents(documents) + if not chunks: + return False + + # Add to existing vectorstore + self.vectorstore.add_documents(chunks) + print(f"Added {len(chunks)} chunks to vector store") + return True + + def get_retriever(self, k=4): + """ + Get a retriever for the vector store + """ + return self.vectorstore.as_retriever(search_kwargs={"k": k}) + + def visualize_vectors(self): + """ + Create a 3D visualization of the vector store + """ + try: + collection = self.vectorstore._collection + result = collection.get(include=['embeddings', 'documents', 'metadatas']) + + if result['embeddings'] is None or len(result['embeddings']) == 0: + print("No embeddings found in vector store") + return None + + vectors = np.array(result['embeddings']) + documents = result['documents'] + metadatas = result['metadatas'] + + if len(vectors) < 2: + print("Not enough vectors for visualization (need at least 2)") + return None + + # Get source info for coloring + sources = [metadata.get('source', 'unknown') for metadata in metadatas] + unique_sources = list(set(sources)) + colors = [['blue', 'green', 'red', 'orange', 'purple', 'cyan'][unique_sources.index(s) % 6] for s in sources] + + # Reduce dimensions for visualization + # Adjust perplexity based on number of samples + n_samples = len(vectors) + perplexity = min(30, max(1, n_samples - 1)) + + tsne = TSNE(n_components=3, 
random_state=42, perplexity=perplexity) + reduced_vectors = tsne.fit_transform(vectors) + + # Create the 3D scatter plot + fig = go.Figure(data=[go.Scatter3d( + x=reduced_vectors[:, 0], + y=reduced_vectors[:, 1], + z=reduced_vectors[:, 2], + mode='markers', + marker=dict(size=5, color=colors, opacity=0.8), + text=[f"Source: {s}<br>
Text: {d[:100]}..." for s, d in zip(sources, documents)], + hoverinfo='text' + )]) + + fig.update_layout( + title='3D Vector Store Visualization', + scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'), + width=900, + height=700, + margin=dict(r=20, b=10, l=10, t=40) + ) + + return fig + + except Exception as e: + print(f"Error creating visualization: {e}") + return None + +# Simple fallback chain implementation +class SimpleConversationalChain: + def __init__(self, llm, retriever, memory): + self.llm = llm + self.retriever = retriever + self.memory = memory + + def invoke(self, inputs): + question = inputs.get("question", "") + # Get relevant documents - try different methods + try: + docs = self.retriever.get_relevant_documents(question) + except AttributeError: + try: + docs = self.retriever.invoke(question) + except: + docs = [] + + context = "\n".join([doc.page_content for doc in docs[:3]]) if docs else "No relevant context found." + + # Create a simple prompt + prompt = f"""Based on the following context, answer the question: + +Context: {context} + +Question: {question} + +Answer:""" + + # Get response from LLM + response = self.llm.invoke(prompt) + return {"answer": response.content if hasattr(response, 'content') else str(response)} + +# Chat System Class +class ChatSystem: + def __init__(self, knowledge_base, model_name=MODEL): + self.knowledge_base = knowledge_base + self.model_name = model_name + self.llm = ChatOpenAI(temperature=0.7, model_name=self.model_name) + self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True) + self.conversation_chain = self._create_conversation_chain() + + def _create_conversation_chain(self): + """ + Create a new conversation chain with the current retriever + """ + retriever = self.knowledge_base.get_retriever() + # Skip the problematic ConversationalRetrievalChain and use simple implementation + print("Using simple conversational chain implementation") + return SimpleConversationalChain(self.llm, retriever, self.memory) + + def reset_conversation(self): + """ + Reset the conversation memory and chain + """ + self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True) + self.conversation_chain = self._create_conversation_chain() + return "Conversation has been reset." + + def chat(self, question, history): + """ + Process a question and return the answer + """ + if not question.strip(): + return "Please ask a question." + + result = self.conversation_chain.invoke({"question": question}) + return result["answer"] + + def update_knowledge_base(self): + """ + Update the conversation chain with the latest knowledge base + """ + self.conversation_chain = self._create_conversation_chain() + +# UI Functions +def handle_file_upload(files): + """ + Process uploaded files and add them to the knowledge base + """ + if not files: + return "No files uploaded." + + documents = [] + for file in files: + try: + docs = load_document(file.name) + if docs: + # Add upload source metadata + for doc in docs: + doc.metadata['source'] = 'upload' + doc.metadata['filename'] = os.path.basename(file.name) + documents.extend(docs) + except Exception as e: + print(f"Error processing file {file.name}: {e}") + + if documents: + success = kb.add_documents(documents) + if success: + # Update the chat system with new knowledge + chat_system.update_knowledge_base() + return f"Successfully processed {len(documents)} documents." + + return "No documents could be processed. Please check file formats." 
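+
+# Optional smoke-test sketch (illustrative only, with assumptions): it drives the same
+# load -> chunk -> embed -> retrieve path that handle_file_upload() uses, but from
+# plain Python instead of the UI. The sample path below is hypothetical, and `kb` and
+# `chat_system` are assumed to be the globals that main() initializes.
+def smoke_test_rag(sample_path="sample_note.txt"):
+    # Load the file with the same loader-dispatch logic the UI uses
+    docs = load_document(sample_path)
+    if not docs:
+        print(f"Could not load {sample_path}")
+        return
+    # Tag metadata the same way handle_file_upload() does
+    for doc in docs:
+        doc.metadata['source'] = 'smoke_test'
+        doc.metadata['filename'] = os.path.basename(sample_path)
+    # Chunk, embed, and store; then refresh the chat chain and ask a question
+    if kb.add_documents(docs):
+        chat_system.update_knowledge_base()
+        print(chat_system.chat("What is this document about?", history=[]))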
+ +def create_ui(): + """ + Create the Gradio UI + """ + with gr.Blocks(theme=gr.themes.Soft()) as app: + gr.Markdown(""" + # Knowledge Worker + Upload documents or ask questions about your knowledge base. + """) + + with gr.Tabs(): + with gr.TabItem("Chat"): + chatbot = gr.ChatInterface( + chat_system.chat, + chatbot=gr.Chatbot(height=500, type="messages"), + textbox=gr.Textbox(placeholder="Ask a question about your documents...", container=False), + title="Knowledge Worker Chat", + type="messages" + ) + reset_btn = gr.Button("Reset Conversation") + reset_btn.click(chat_system.reset_conversation, inputs=None, outputs=gr.Textbox()) + + with gr.TabItem("Upload Documents"): + with gr.Column(): + file_output = gr.Textbox(label="Upload Status") + upload_button = gr.UploadButton( + "Click to Upload Files", + file_types=[".pdf", ".docx", ".txt", ".md", ".xlsx"], + file_count="multiple" + ) + upload_button.upload(handle_file_upload, upload_button, file_output) + + with gr.TabItem("Visualize Knowledge"): + visualize_btn = gr.Button("Generate Vector Visualization") + plot_output = gr.Plot(label="Vector Space Visualization") + visualize_btn.click(kb.visualize_vectors, inputs=None, outputs=plot_output) + + return app + +def main(): + """ + Main function to initialize and run the knowledge worker + """ + global kb, chat_system + + print("=" * 60) + print("Initializing Knowledge Worker...") + print("=" * 60) + + try: + # Initialize the knowledge base + print("Setting up vector database...") + kb = KnowledgeBase(DB_NAME) + print("Vector database initialized successfully") + + # Google Drive integration removed + + # Initialize the chat system + print("\nSetting up chat system...") + chat_system = ChatSystem(kb) + print("Chat system initialized successfully") + + # Launch the Gradio app + print("\nLaunching Gradio interface...") + print("=" * 60) + print("The web interface will open in your browser") + print("You can also access it at the URL shown below") + print("=" * 60) + + app = create_ui() + app.launch(inbrowser=True) + + except Exception as e: + print(f"Error initializing Knowledge Worker: {e}") + print("Please check your configuration and try again.") + return + +if __name__ == "__main__": + main() diff --git a/week5/community-contributions/week5_jom/Exercise_week5_jom.ipynb b/week5/community-contributions/week5_jom/Exercise_week5_jom.ipynb new file mode 100644 index 0000000..8881804 --- /dev/null +++ b/week5/community-contributions/week5_jom/Exercise_week5_jom.ipynb @@ -0,0 +1,623 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6f0f38e7", + "metadata": {}, + "source": [ + "# Email Mindmap Demo (Week 5 Community Contribution)\n", + "\n", + "Welcome to the **Email Mindmap Demo** notebook! This demo walks you through a workflow for exploring and visualizing email relationships using embeddings and mindmaps.\n", + "\n", + "---\n", + "\n", + "## 📋 Workflow Overview\n", + "\n", + "1. **Load/Create Synthetic Email Data** \n", + " Generate or load varied types of emails: work, personal, family, subscriptions, etc.\n", + "\n", + "2. **Generate Embeddings** \n", + " Use an open-source model to create vector embeddings for email content.\n", + "\n", + "3. **Build & Visualize a Mindmap** \n", + " Construct a mindmap of email relationships and visualize it interactively using `networkx` and `matplotlib`.\n", + "\n", + "4. 
**Question-Answering Interface** \n", + " Query the email content and the mindmap using a simple Q&A interface powered by Gradio.\n", + "\n", + "---\n", + "\n", + "## ⚙️ Requirements\n", + "\n", + "> **Tip:** \n", + "> I'm including an example of the synthetic emails in case you don't want to run that part.\n", + "> Might need to install other libraries like pyvis, nbformat and faiss-cpu\n", + "\n", + "\n", + "## ✨ Features\n", + "\n", + "- Synthetic generation of varied emails (work, personal, family, subscriptions)\n", + "- Embedding generation with open-source models (hugging face sentence-transformer)\n", + "- Interactive mindmap visualization (`networkx`, `pyvis`)\n", + "- Simple chatbot interface (Gradio) and visualization of mindmap created\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a9aeb363", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n", + "Anthropic API Key exists and begins sk-ant-\n", + "Google API Key exists and begins AI\n", + "OLLAMA API Key exists and begins 36\n" + ] + } + ], + "source": [ + "# imports\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "\n", + "load_dotenv(override=True)\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n", + "google_api_key = os.getenv('GOOGLE_API_KEY')\n", + "ollama_api_key = os.getenv('OLLAMA_API_KEY')\n", + "\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "if anthropic_api_key:\n", + " print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n", + "else:\n", + " print(\"Anthropic API Key not set (and this is optional)\")\n", + "\n", + "if google_api_key:\n", + " print(f\"Google API Key exists and begins {google_api_key[:2]}\")\n", + "else:\n", + " print(\"Google API Key not set (and this is optional)\")\n", + "\n", + "if ollama_api_key:\n", + " print(f\"OLLAMA API Key exists and begins {ollama_api_key[:2]}\")\n", + "else:\n", + " print(\"OLLAMA API Key not set (and this is optional)\")\n", + "\n", + "# Connect to client libraries\n", + "\n", + "openai = OpenAI()\n", + "\n", + "anthropic_url = \"https://api.anthropic.com/v1/\"\n", + "gemini_url = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", + "ollama_url = \"http://localhost:11434/v1\"\n", + "\n", + "anthropic = OpenAI(api_key=anthropic_api_key, base_url=anthropic_url)\n", + "gemini = OpenAI(api_key=google_api_key, base_url=gemini_url)\n", + "ollama = OpenAI(api_key=ollama_api_key, base_url=ollama_url)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "b8ddce62", + "metadata": {}, + "source": [ + "## Preparation of synthetic data (could have been week2 work)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2e250912", + "metadata": {}, + "outputs": [], + "source": [ + "#using ollama gpt oss 120b cloud i'm going to create synthetic emails using a persona.\n", + "#they are going to be saved in a json file with different keys\n", + "from pydantic import BaseModel, Field\n", + "from typing import List, Optional\n", + "\n", + "\n", + "class Email(BaseModel):\n", + " sender: str = Field(description=\"Email address of the sender\")\n", + " subject: str = Field(description=\"Email subject line\")\n", + " body: str = 
Field(description=\"Email body content\")\n", + " timestamp: str = Field(description=\"ISO 8601 timestamp when email was received\")\n", + " category: str = Field(description=\"Category of the email\")\n", + "\n", + "class EmailBatch(BaseModel):\n", + " emails: List[Email] = Field(description=\"List of generated emails\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1f67fdb3", + "metadata": {}, + "outputs": [], + "source": [ + "def create_persona(name: str, age: int, occupation: str, \n", + " interests: List[str], family_status: str) -> str:\n", + " persona = f\"\"\"\n", + " You are generating synthetic emails for a realistic inbox simulation.\n", + "\n", + " **Person Profile:**\n", + " - Name: {name}\n", + " - Age: {age}\n", + " - Occupation: {occupation}\n", + " - Interests: {', '.join(interests)}\n", + " - Family Status: {family_status}\n", + "\n", + " **Email Categories to Include:**\n", + " 1. **Work Emails**: Project updates, meeting invitations, colleague communications, \n", + " performance reviews, company announcements\n", + " 2. **Purchases**: Order confirmations, shipping notifications, delivery updates, \n", + " receipts from various retailers (Amazon, local shops, etc.)\n", + " 3. **Subscriptions**: Newsletter updates, streaming services (Netflix, Spotify), \n", + " software subscriptions (Adobe, Microsoft 365), magazine subscriptions\n", + " 4. **Family**: Communications with parents, siblings, children, extended family members,\n", + " family event planning, photo sharing\n", + " 5. **Friends**: Social plans, birthday wishes, casual conversations, group hangouts,\n", + " catching up messages\n", + " 6. **Finance**: Bank statements, credit card bills, investment updates, tax documents,\n", + " payment reminders\n", + " 7. **Social Media**: Facebook notifications, LinkedIn updates, Instagram activity,\n", + " Twitter mentions\n", + " 8. 
**Personal**: Doctor appointments, gym memberships, utility bills, insurance updates\n", + "\n", + " **Instructions:**\n", + " - Generate realistic email content that reflects the person's life over time\n", + " - Include temporal patterns (more work emails on weekdays, more personal on weekends)\n", + " - Create realistic sender names and email addresses\n", + " - Vary email length and formality based on context\n", + " - Include realistic subject lines\n", + " - Make emails interconnected when appropriate (e.g., follow-up emails, conversation threads)\n", + " - Include seasonal events (holidays, birthdays, annual renewals)\n", + " \"\"\"\n", + " return persona\n", + "\n", + "persona_description = create_persona(\n", + " name=\"John Doe\",\n", + " age=30,\n", + " occupation=\"Software Engineer\",\n", + " interests=[\"technology\", \"reading\", \"traveling\"],\n", + " family_status=\"single\"\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cec185e3", + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "from datetime import datetime, timedelta\n", + "import random\n", + "from typing import List\n", + "\n", + "def generate_synthetic_emails(\n", + " persona_description: str,\n", + " num_emails: int,\n", + " start_date: str,\n", + " end_date: str,\n", + " model: str = \"gpt-4o-2024-08-06\"\n", + ") -> List[Email]:\n", + " \"\"\"\n", + " Requires an OpenAI model because it relies on the structured-output parse feature.\n", + " Generates synthetic emails using OpenAI's structured output feature.\n", + " \n", + " Args:\n", + " persona_description: Detailed persona description\n", + " num_emails: Number of emails to generate per batch\n", + " start_date: Start date for email timestamps\n", + " end_date: End date for email timestamps\n", + " model: OpenAI model to use (must support structured outputs)\n", + " \n", + " Returns:\n", + " List of Email objects\n", + " \"\"\"\n", + " \n", + " # Calculate date range for context\n", + " date_range_context = f\"\"\"\n", + " Generate emails with timestamps between {start_date} and {end_date}.\n", + " Distribute emails naturally across this time period, with realistic patterns:\n", + " - More emails during business hours on weekdays\n", + " - Fewer emails late at night\n", + " - Occasional weekend emails\n", + " - Bursts of activity around events or busy periods\n", + " \"\"\"\n", + " \n", + " # System message combining persona and structure instructions\n", + " system_message = f\"\"\"\n", + " {persona_description}\n", + "\n", + " {date_range_context}\n", + "\n", + " Generate {num_emails} realistic emails that fit this person's life. 
\n", + " Ensure variety in categories, senders, and content while maintaining realism.\n", + " \"\"\"\n", + " \n", + " try:\n", + " client = OpenAI()\n", + "\n", + " response = client.chat.completions.parse(\n", + " model=model,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": system_message\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"Generate {num_emails} diverse, realistic emails for this person's inbox.\"\n", + " }\n", + " ],\n", + " response_format=EmailBatch,\n", + " )\n", + " return response.choices[0].message.parsed.emails\n", + " \n", + " except Exception as e:\n", + " print(f\"Error generating emails: {e}\")\n", + " return []\n", + "\n", + "\n", + "def save_emails_to_json(emails: List[Email], filename: str):\n", + " \"\"\"\n", + " Saves emails to a JSON file.\n", + " \"\"\"\n", + " import json\n", + " \n", + " emails_dict = [email.model_dump() for email in emails]\n", + " \n", + " with open(filename, 'w', encoding='utf-8') as f:\n", + " json.dump(emails_dict, f, indent=2, ensure_ascii=False)\n", + " \n", + " print(f\"Saved {len(emails)} emails to {filename}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "be31f352", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "now\n" + ] + } + ], + "source": [ + "mails_2 = generate_synthetic_emails(\n", + " persona_description = persona_description,\n", + " num_emails = 100,\n", + " start_date = '2024-06-01',\n", + " end_date = '2025-01-01',\n", + " model = \"gpt-4o\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "24d844f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved 101 emails to emails2.json\n" + ] + } + ], + "source": [ + "save_emails_to_json(mails_2, 'emails2.json')" + ] + }, + { + "cell_type": "markdown", + "id": "2b9c704e", + "metadata": {}, + "source": [ + "## Create embeddings for the mails\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "777012f8", + "metadata": {}, + "outputs": [], + "source": [ + "# imports for langchain, plotly and Chroma\n", + "\n", + "from langchain.document_loaders import DirectoryLoader, TextLoader\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.schema import Document\n", + "from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n", + "from langchain_chroma import Chroma\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.manifold import TSNE\n", + "import numpy as np\n", + "import plotly.graph_objects as go\n", + "from langchain.memory import ConversationBufferMemory\n", + "from langchain.chains import ConversationalRetrievalChain\n", + "from langchain.embeddings import HuggingFaceEmbeddings\n", + "import json\n", + "from langchain.vectorstores import FAISS\n", + "\n", + "#MODEL = \"gpt-4o-mini\"\n", + "db_name = \"vector_db\"" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "ce95d9c7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of chunks: 206\n", + "Sample metadata fields: ['sender', 'timestamp', 'category']\n" + ] + } + ], + "source": [ + "# Read in emails from the emails.json file and construct LangChain documents\n", + "\n", + "\n", + "with open(\"emails.json\", \"r\", encoding=\"utf-8\") as f:\n", + " emails = json.load(f)\n", + "\n", + "documents = []\n", + "for email in emails:\n", + " # Extract metadata 
+  {
+   "cell_type": "markdown",
+   "id": "78ca65bb",
+   "metadata": {},
+   "source": [
+    "## Visualizing the mindmap"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "id": "a99dd2d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import networkx as nx\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.metrics.pairwise import cosine_similarity\n",
+    "import plotly.graph_objects as go\n",
+    "import numpy as np\n",
+    "from sklearn.cluster import KMeans\n",
+    "from sklearn.manifold import TSNE  # Or use UMAP\n",
+    "from pyvis.network import Network\n",
+    "\n",
+    "# Here, emails is the list of email dicts (with 'subject' and 'body' keys) loaded above\n",
+    "\n",
+    "# Build a similarity graph; this expects one embedding per email so that the\n",
+    "# similarity matrix indices line up with the email node labels\n",
+    "def build_mindmap_html(emails, email_embeddings, threshold=0.6):\n",
+    "    similarity = cosine_similarity(email_embeddings)\n",
+    "\n",
+    "    G = nx.Graph()\n",
+    "    for i, email in enumerate(emails):\n",
+    "        G.add_node(i, label=email['subject'][:80], title=email['body'][:50])  # Custom hover text\n",
+    "\n",
+    "    for i in range(len(emails)):\n",
+    "        for j in range(i+1, len(emails)):\n",
+    "            if similarity[i][j] > threshold:\n",
+    "                G.add_edge(i, j, weight=float(similarity[i][j]))\n",
+    "\n",
+    "    # Convert to pyvis network\n",
+    "    nt = Network(notebook=True, height='700px', width='100%', bgcolor='#222222', font_color='white')\n",
+    "    nt.from_nx(G)\n",
+    "    # Swap single quotes for double quotes so the HTML can sit inside a single-quoted srcdoc attribute\n",
+    "    html = nt.generate_html().replace(\"'\", \"\\\"\")\n",
+    "    return html\n"
+   ]
+  },
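+  {
+   "cell_type": "markdown",
+   "id": "6e0aa001",
+   "metadata": {},
+   "source": [
+    "To eyeball the graph outside Gradio, the generated HTML can simply be written to disk and opened in a browser. A minimal sketch; `mindmap.html` is an arbitrary output path."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6e0aa002",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Preview the mindmap outside Gradio by writing the pyvis HTML to disk\n",
+    "# (\"mindmap.html\" is an arbitrary output path)\n",
+    "html = build_mindmap_html(emails, email_embeddings)\n",
+    "with open(\"mindmap.html\", \"w\", encoding=\"utf-8\") as f:\n",
+    "    f.write(html)\n",
+    "print(f\"Wrote {len(html)} characters to mindmap.html\")\n"
+   ]
+  },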
+  {
+   "cell_type": "markdown",
+   "id": "53a2fbaf",
+   "metadata": {},
+   "source": [
+    "## Putting it all together in a Gradio app\n",
+    "The app needs a chat interface for asking questions, plus the mindmap visualization.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "id": "161144ac",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create a new Chat with OpenAI\n",
+    "MODEL=\"gpt-4o-mini\"\n",
+    "llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n",
+    "\n",
+    "# set up the conversation memory for the chat\n",
+    "memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n",
+    "\n",
+    "# the retriever is an abstraction over the VectorStore that will be used during RAG\n",
+    "retriever = vectorstore.as_retriever()\n",
+    "from langchain_core.callbacks import StdOutCallbackHandler\n",
+    "\n",
+    "# putting it together: set up the conversation chain with the gpt-4o-mini LLM, the vector store and memory\n",
+    "# (the _debug variant prints the chain's internals via StdOutCallbackHandler; the UI uses the quiet one)\n",
+    "conversation_chain_debug = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, callbacks=[StdOutCallbackHandler()])\n",
+    "conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n",
+    "\n",
+    "# Wrapping that in a function\n",
+    "\n",
+    "def chat(question, history):\n",
+    "    result = conversation_chain.invoke({\"question\": question})\n",
+    "    return result[\"answer\"]\n"
+   ]
+  },
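+  {
+   "cell_type": "markdown",
+   "id": "7a0bcd01",
+   "metadata": {},
+   "source": [
+    "Before wiring the chain into the UI, a one-off call confirms retrieval and memory work end to end. A small smoke test with an arbitrary question; note that it shares the chain's conversation memory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a0bcd02",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Smoke-test the RAG chain before putting it behind the UI\n",
+    "print(chat(\"What categories of emails do I receive most often?\", []))\n"
+   ]
+  },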
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: When cdn_resources is 'local' jupyter notebook has issues displaying graphics on chrome/safari. Use cdn_resources='in_line' or cdn_resources='remote' if you have issues viewing graphics in a notebook.\n", + "Warning: When cdn_resources is 'local' jupyter notebook has issues displaying graphics on chrome/safari. Use cdn_resources='in_line' or cdn_resources='remote' if you have issues viewing graphics in a notebook.\n" + ] + } + ], + "source": [ + "\n", + "import gradio as gr\n", + "\n", + "def show_mindmap():\n", + " # Call build_mindmap_html to generate the HTML\n", + " html = build_mindmap_html(emails, all_embeddings)\n", + " return f\"\"\"\"\"\"\n", + "\n", + "\n", + "with gr.Blocks(title=\"Mindmap & Email Chatbot\") as demo:\n", + " gr.Markdown(\"# 📧 Mindmap Visualization & Email QA Chatbot\")\n", + " with gr.Row():\n", + " chatbot = gr.ChatInterface(fn=chat, title=\"Ask about your emails\",\n", + " examples=[\n", + " \"What is my most important message?\",\n", + " \"Who have I been communicating with?\",\n", + " \"Summarize recent emails\"\n", + " ],\n", + ")\n", + " mindmap_html = gr.HTML(\n", + " show_mindmap,\n", + " label=\"🧠 Mindmap of Your Emails\",\n", + " )\n", + " # Reduce height: update show_mindmap (elsewhere) to ~400px, or do inline replace for the demo here:\n", + " # mindmap_html = gr.HTML(lambda: show_mindmap().replace(\"height: 600px\", \"height: 400px\"))\n", + " \n", + "demo.launch(inbrowser=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "221a9d98", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}