{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dataset-generator-intro",
   "metadata": {},
   "source": [
    "# Dataset Generator\n",
    "\n",
    "Generates rows of basketball player data with an LLM (OpenAI, Anthropic, or a local Ollama model) and saves them as CSV, TSV, JSONL, Parquet, or Arrow. The last cell launches a Gradio UI.\n",
    "\n",
    "**Requirements**\n",
    "\n",
    "- `OPENAI_API_KEY` and `ANTHROPIC_API_KEY` in a `.env` file.\n",
    "- For `ollama:` models, the `ollama` command line tool must be installed and running locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c08309b8-13f0-45bb-a3ea-7b01f05a7346",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "import subprocess\n",
    "from typing import List\n",
    "\n",
    "import anthropic\n",
    "import gradio as gr\n",
    "import openai\n",
    "import pandas as pd\n",
    "import pyarrow as pa\n",
    "from dotenv import load_dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5efd903-e683-4e7f-8747-2998e23a0751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load API keys from the .env file (values there override the environment)\n",
    "load_dotenv(override=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce49b86a-53f4-4d4f-a721-0d66d9c1b070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Schema Definition ---\n",
    "# Each entry is (column name, type, example shown to the model).\n",
    "SCHEMA = [\n",
    "    (\"Team\", \"TEXT\", '\"Toronto Raptors\"'),\n",
    "    (\"NAME\", \"TEXT\", '\"Otto Porter Jr.\"'),\n",
    "    (\"Jersey\", \"TEXT\", '\"10\", or \"NA\" if null'),\n",
    "    (\"POS\", \"TEXT\", 'One of [\"PF\",\"SF\",\"G\",\"C\",\"SG\",\"F\",\"PG\"]'),\n",
    "    (\"AGE\", \"INT\", 'integer age in years, e.g., 22'),\n",
    "    (\"HT\", \"TEXT\", '`6\\' 7\"` or `6\\' 10\"`'),\n",
    "    (\"WT\", \"TEXT\", '\"232 lbs\"'),\n",
    "    (\"COLLEGE\", \"TEXT\", '\"Michigan\", or \"--\" if null'),\n",
    "    (\"SALARY\", \"TEXT\", '\"$9,945,830\", or \"--\" if null')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93743e57-c2c5-43e5-8fa1-2e242085db07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default schema text for the textbox: one numbered line per column\n",
    "DEFAULT_SCHEMA_TEXT = \"\\n\".join(\n",
    "    f\"{idx}. {name} ({dtype}) Example: {example}\"\n",
    "    for idx, (name, dtype, example) in enumerate(SCHEMA, start=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c58595-6fdd-48f5-a253-ccba352cb385",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Available models\n",
    "MODELS = [\n",
    "    \"gpt-4o\",\n",
    "    \"claude-3-5-haiku-20241022\",\n",
    "    \"ollama:llama3.2:latest\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08cd9ce2-8685-46b5-95d0-811b8025696f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Available file formats\n",
    "FILE_FORMATS = [\".csv\", \".tsv\", \".jsonl\", \".parquet\", \".arrow\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d68c7f-6f49-4efa-b075-f1e7db2ab527",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(n: int, schema_text: str, system_prompt: str) -> str:\n",
    "    \"\"\"Build the generation prompt: system prompt, row count, schema and output rules.\"\"\"\n",
    "    prompt = f\"\"\"\n",
    "{system_prompt}\n",
    "\n",
    "Generate {n} rows of realistic basketball player data in JSONL format, each line a JSON object with the following fields:\n",
    "\n",
    "{schema_text}\n",
    "\n",
    "Do NOT repeat column values from one row to another.\n",
    "\n",
    "Only output valid JSONL.\n",
    "\"\"\"\n",
    "    return prompt.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc68f1e-4fbe-45dc-aa36-ce5f718ef6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- LLM Interface ---\n",
    "def _parse_jsonl(content: str) -> List[dict]:\n",
    "    \"\"\"Extract the JSON objects from raw model output, one per line.\n",
    "\n",
    "    Lines that do not start with an opening brace (prose, Markdown code\n",
    "    fences) are ignored. Lines that look like objects but fail to decode,\n",
    "    such as a last row cut off by the token limit, are skipped so the rows\n",
    "    that did parse are kept instead of the whole response being discarded.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for raw_line in (content or \"\").splitlines():\n",
    "        # Tolerate a trailing comma left over from array-style output.\n",
    "        line = raw_line.strip().rstrip(\",\").rstrip()\n",
    "        if not line.startswith(\"{\"):\n",
    "            continue\n",
    "        try:\n",
    "            records.append(json.loads(line))\n",
    "        except json.JSONDecodeError:\n",
    "            continue\n",
    "    return records\n",
    "\n",
    "\n",
    "def query_model(prompt: str, model: str = \"gpt-4o\") -> List[dict]:\n",
    "    \"\"\"Send `prompt` to the selected backend and return the parsed rows.\n",
    "\n",
    "    The backend is chosen from the model name prefix: \"gpt\" uses the OpenAI\n",
    "    chat completions API, \"claude\" uses the Anthropic messages API and\n",
    "    \"ollama:\" runs a local model through the `ollama` command line tool.\n",
    "\n",
    "    Args:\n",
    "        prompt: Full prompt text, sent as a single user message.\n",
    "        model: Model identifier, e.g. one of the entries in MODELS.\n",
    "\n",
    "    Returns:\n",
    "        One dict per JSON object line found in the model output.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the model name is unsupported or the call fails.\n",
    "    \"\"\"\n",
    "    model_key = model.lower()\n",
    "    try:\n",
    "        if model_key.startswith(\"gpt\"):\n",
    "            client = openai.OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.choices[0].message.content\n",
    "\n",
    "        elif model_key.startswith(\"claude\"):\n",
    "            client = anthropic.Anthropic(api_key=os.getenv(\"ANTHROPIC_API_KEY\"))\n",
    "            response = client.messages.create(\n",
    "                model=model,\n",
    "                messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "                max_tokens=4000,\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.content[0].text\n",
    "\n",
    "        elif model_key.startswith(\"ollama:\"):\n",
    "            # Split on the first colon only so the tag survives:\n",
    "            # \"ollama:llama3.2:latest\" -> \"llama3.2:latest\".\n",
    "            ollama_model = model.split(\":\", 1)[1]\n",
    "            result = subprocess.run(\n",
    "                [\"ollama\", \"run\", ollama_model],\n",
    "                input=prompt,\n",
    "                text=True,\n",
    "                encoding=\"utf-8\",  # do not depend on the platform locale\n",
    "                capture_output=True\n",
    "            )\n",
    "            if result.returncode != 0:\n",
    "                raise RuntimeError(f\"Ollama error: {result.stderr}\")\n",
    "            content = result.stdout\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported model '{model}'. Use one of: {', '.join(MODELS)}\")\n",
    "\n",
    "        # Parse JSONL output\n",
    "        return _parse_jsonl(content)\n",
    "\n",
    "    except Exception as e:\n",
    "        raise RuntimeError(f\"Model query failed: {e}\") from e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e3f5f5-e99c-429c-bea9-69d554c58c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Output Formatter ---\n",
    "def save_dataset(records: List[dict], file_format: str, filename: str):\n",
    "    \"\"\"Write `records` to `filename` in the requested `file_format`.\n",
    "\n",
    "    Args:\n",
    "        records: Rows to save, one dict per row.\n",
    "        file_format: One of the extensions in FILE_FORMATS (with the dot).\n",
    "        filename: Destination path.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `file_format` is not a supported extension.\n",
    "    \"\"\"\n",
    "    df = pd.DataFrame(records)\n",
    "    if file_format == \".csv\":\n",
    "        df.to_csv(filename, index=False)\n",
    "    elif file_format == \".tsv\":\n",
    "        df.to_csv(filename, sep=\"\\t\", index=False)\n",
    "    elif file_format == \".jsonl\":\n",
    "        # Explicit UTF-8 so the file does not depend on the platform locale.\n",
    "        with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
    "            for record in records:\n",
    "                f.write(json.dumps(record) + \"\\n\")\n",
    "    elif file_format == \".parquet\":\n",
    "        df.to_parquet(filename, engine=\"pyarrow\", index=False)\n",
    "    elif file_format == \".arrow\":\n",
    "        table = pa.Table.from_pandas(df)\n",
    "        with pa.OSFile(filename, \"wb\") as sink:\n",
    "            with pa.ipc.new_file(sink, table.schema) as writer:\n",
    "                writer.write(table)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported file format\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe258e84-66f4-4fe7-99c0-75b24148e147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Main Generation Function ---\n",
    "def generate_dataset(schema_text, system_prompt, model, nr_records, file_format, save_as):\n",
    "    \"\"\"Generate a dataset with the chosen model, save it and build a preview.\n",
    "\n",
    "    Returns:\n",
    "        A `(status_message, preview)` tuple for the two Gradio outputs.\n",
    "        `preview` is a DataFrame with the first 10 rows, or None on failure.\n",
    "        Errors are reported through the status message, never raised.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Validation\n",
    "        # The Number field yields None when cleared and may yield a float.\n",
    "        if nr_records is None or int(nr_records) <= 10:\n",
    "            return \"❌ Error: Nr_records must be greater than 10.\", None\n",
    "        nr_records = int(nr_records)\n",
    "\n",
    "        if file_format not in FILE_FORMATS:\n",
    "            return \"❌ Error: Invalid file format specified.\", None\n",
    "\n",
    "        # Strip first so stray whitespace does not end up inside the filename.\n",
    "        save_as = (save_as or \"\").strip()\n",
    "        if not save_as:\n",
    "            save_as = f\"basketball_dataset{file_format}\"\n",
    "        elif not save_as.endswith(file_format):\n",
    "            save_as = save_as + file_format\n",
    "\n",
    "        # Generate prompt\n",
    "        prompt = get_prompt(nr_records, schema_text, system_prompt)\n",
    "\n",
    "        # Query model\n",
    "        records = query_model(prompt, model=model)\n",
    "\n",
    "        if not records:\n",
    "            return \"❌ Error: No valid records generated from the model.\", None\n",
    "\n",
    "        # Save dataset\n",
    "        save_dataset(records, file_format, save_as)\n",
    "\n",
    "        # Create preview\n",
    "        df = pd.DataFrame(records)\n",
    "        preview = df.head(10)  # Show first 10 rows\n",
    "\n",
    "        success_message = f\"✅ Dataset generated successfully!\\n📁 Saved to: {save_as}\\n📊 Generated {len(records)} records\"\n",
    "\n",
    "        return success_message, preview\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"❌ Error: {str(e)}\", None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405a9d-b4cd-43d9-82f6-ff3512b4541f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Gradio Interface ---\n",
    "def create_interface():\n",
    "    \"\"\"Build the Gradio Blocks UI and wire the Generate button to generate_dataset.\"\"\"\n",
    "    with gr.Blocks(title=\"Dataset Generator\", theme=gr.themes.Soft()) as interface:\n",
    "        gr.Markdown(\"# Dataset Generator\")\n",
    "        gr.Markdown(\"Generate realistic datasets using AI models\")\n",
    "\n",
    "        with gr.Row():\n",
    "            with gr.Column(scale=2):\n",
    "                schema_input = gr.Textbox(\n",
    "                    label=\"Schema\",\n",
    "                    value=DEFAULT_SCHEMA_TEXT,\n",
    "                    lines=15,\n",
    "                    placeholder=\"Define your dataset schema here...\"\n",
    "                )\n",
    "\n",
    "                system_prompt_input = gr.Textbox(\n",
    "                    label=\"Prompt\",\n",
    "                    value=\"You are a helpful assistant that generates realistic basketball player data.\",\n",
    "                    lines=1,\n",
    "                    placeholder=\"Enter system prompt for the model...\"\n",
    "                )\n",
    "\n",
    "                with gr.Row():\n",
    "                    model_dropdown = gr.Dropdown(\n",
    "                        label=\"Model\",\n",
    "                        choices=MODELS,\n",
    "                        value=MODELS[1],  # Default to Claude\n",
    "                        interactive=True\n",
    "                    )\n",
    "\n",
    "                    nr_records_input = gr.Number(\n",
    "                        label=\"Nr. records\",\n",
    "                        value=25,\n",
    "                        minimum=11,\n",
    "                        maximum=1000,\n",
    "                        step=1,\n",
    "                        precision=0  # hand the callback an integer\n",
    "                    )\n",
    "\n",
    "                with gr.Row():\n",
    "                    file_format_dropdown = gr.Dropdown(\n",
    "                        label=\"File format\",\n",
    "                        choices=FILE_FORMATS,\n",
    "                        value=\".csv\",\n",
    "                        interactive=True\n",
    "                    )\n",
    "\n",
    "                    save_as_input = gr.Textbox(\n",
    "                        label=\"Save as\",\n",
    "                        value=\"basketball_dataset\",\n",
    "                        placeholder=\"Enter filename (extension will be added automatically)\"\n",
    "                    )\n",
    "\n",
    "                generate_btn = gr.Button(\"🚀 Generate\", variant=\"primary\", size=\"lg\")\n",
    "\n",
    "            with gr.Column(scale=1):\n",
    "                output_status = gr.Textbox(\n",
    "                    label=\"Status\",\n",
    "                    lines=4,\n",
    "                    interactive=False\n",
    "                )\n",
    "\n",
    "                output_preview = gr.Dataframe(\n",
    "                    label=\"Preview (First 10 rows)\",\n",
    "                    interactive=False,\n",
    "                    wrap=True\n",
    "                )\n",
    "\n",
    "        # Connect the generate button\n",
    "        generate_btn.click(\n",
    "            fn=generate_dataset,\n",
    "            inputs=[\n",
    "                schema_input,\n",
    "                system_prompt_input,\n",
    "                model_dropdown,\n",
    "                nr_records_input,\n",
    "                file_format_dropdown,\n",
    "                save_as_input\n",
    "            ],\n",
    "            outputs=[output_status, output_preview]\n",
    "        )\n",
    "\n",
    "        gr.Markdown(\"\"\"\n",
    "        ### 📝 Instructions:\n",
    "        1. **Schema**: Define the structure of your dataset (pre-filled with basketball player schema)\n",
    "        2. **Prompt**: System prompt to guide the AI model\n",
    "        3. **Model**: Choose between GPT, Claude, or Ollama models\n",
    "        4. **Nr. records**: Number of records to generate (minimum 11)\n",
    "        5. **File format**: Choose output format (.csv, .tsv, .jsonl, .parquet, .arrow)\n",
    "        6. **Save as**: Filename (extension added automatically)\n",
    "        7. Click **Generate** to create your dataset\n",
    "\n",
    "        ### 🔧 Requirements:\n",
    "        - Set up your API keys in `.env` file (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`)\n",
    "        - For Ollama models, ensure Ollama is installed and running locally\n",
    "        \"\"\")\n",
    "\n",
    "    return interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fd2b91-2578-4224-b9dd-e28caf6a0a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "interface = create_interface()\n",
    "interface.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}