From 7ad790be3d6d4d0f459b85e768f765b5ee6e10ef Mon Sep 17 00:00:00 2001 From: Ransford Okpoti Date: Mon, 27 Oct 2025 11:31:38 +0000 Subject: [PATCH] cohorent data generator --- ...skills-week3-coherent-data-generator.ipynb | 751 ++++++++++++++++++ 1 file changed, 751 insertions(+) create mode 100644 week3/community-contributions/ranskills-week3-coherent-data-generator.ipynb diff --git a/week3/community-contributions/ranskills-week3-coherent-data-generator.ipynb b/week3/community-contributions/ranskills-week3-coherent-data-generator.ipynb new file mode 100644 index 0000000..6be1ba4 --- /dev/null +++ b/week3/community-contributions/ranskills-week3-coherent-data-generator.ipynb @@ -0,0 +1,751 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "tqSpfJGnme7y" + ], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# ✨ Coherent Data Generator\n", + "\n", + "## In real life, data has meaning, relationships, etc. and this is where this tool shines.\n", + "\n", + "Dependencies between fields are detected and a coherent data is generated.\n", + "Example:\n", + "When asked to generate data with **Ghana** cited as the context, fields like `name`, `food`, etc. will be Ghanaian. Fields such as phone number will have the appropriate prefix of `+233`, etc.\n", + "\n", + "This is better than Faker.\n", + "\n", + "## Steps\n", + "Schema -> Generate Data\n", + "\n", + "Schema Sources:\n", + "- Use the guided schema builder\n", + "- Bring your own schema from an SQL Data Definition Language (DDL)\n", + "- Prompting\n", + "- Providing a domain to an old-hat to definition features for a dataset" + ], + "metadata": { + "id": "KbMea_UrO3Ke" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cN8z-QNlFtYc" + }, + "outputs": [], + "source": [ + "import json\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", + "import torch\n", + "import pandas as pd\n", + "\n", + "from pydantic import BaseModel, Field\n", + "from IPython.display import display, Markdown" + ] + }, + { + "cell_type": "code", + "source": [ + "model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n", + "\n", + "device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'\n", + "print(f'Device: {device}')\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " dtype=\"auto\",\n", + " device_map=\"auto\"\n", + ")" + ], + "metadata": { + "id": "DOBBN3P2GD2O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Schema Definitions" + ], + "metadata": { + "id": "HSUebXa1O3MM" + } + }, + { + "cell_type": "code", + "source": [ + "# This is for future use where errors in SQL DDL statements can be fixed if the\n", + "# specifies that from the UI\n", + "class SQLValidationResult(BaseModel):\n", + " is_valid: bool\n", + " is_fixable: bool\n", + " reason: str = Field(default='', description='validation failure reason')\n", + "\n", + "\n", + "class FieldDescriptor(BaseModel):\n", + " name: str = Field(..., description='Name of the field')\n", + " data_type: str = Field(..., description='Type of the field')\n", + " nullable: bool\n", + " 
description: str = Field(..., description='Description of the field')\n", + "\n", + "\n", + "class Schema(BaseModel):\n", + " name: str = Field(..., description='Name of the schema')\n", + " fields: list[FieldDescriptor] = Field(..., description='List of fields in the schema')" + ], + "metadata": { + "id": "5LNM76OQjAw6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## LLM Interactions" + ], + "metadata": { + "id": "6QjitfTBPa1E" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Generate Content from LLM" + ], + "metadata": { + "id": "dXiRHok7Peir" + } + }, + { + "cell_type": "code", + "source": [ + "def generate(messages: list[dict[str, str]], temperature: float = 0.1) -> any:\n", + " text = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " )\n", + " model_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", + "\n", + " generated_ids = model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=16384,\n", + " temperature=temperature\n", + " )\n", + "\n", + " output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + " content = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "\n", + " return content" + ], + "metadata": { + "id": "daTUVG8_PmvM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Generate Data Given A Valid Schema" + ], + "metadata": { + "id": "sBHJKn8qQhM5" + } + }, + { + "cell_type": "code", + "source": [ + "def generate_data(schema: str, context: str = '', num_records: int = 5):\n", + " system_prompt = f'''\n", + " You are synthetic data generator, you generate data based on the given schema\n", + " specific JSON structure.\n", + " When a context is provided, intelligently use that to drive the field generation.\n", + "\n", + " Example:\n", + " If Africa is given at the context, fields like name, first_name, last_name, etc.\n", + " that can be derived from Africa will be generated.\n", + "\n", + " If no context is provided, generate data randomly.\n", + "\n", + " Output an array of JSON objects.\n", + " '''\n", + "\n", + " prompt = f'''\n", + " Generate {num_records}:\n", + "\n", + " Schema:\n", + " {schema}\n", + "\n", + " Context:\n", + " {context}\n", + " '''\n", + "\n", + " messages = [\n", + " {'role': 'system', 'content': system_prompt},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + "\n", + " return generate(messages)" + ], + "metadata": { + "id": "Fla8UQf4Qm5l" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### SQL" + ], + "metadata": { + "id": "izrClU6VPsZp" + } + }, + { + "cell_type": "code", + "source": [ + "def sql_validator(ddl: str):\n", + " system_prompt = '''\n", + " You are an SQL validator, your task is to validate if the given SQL is valid or not.\n", + " ONLY return a binary response of 1 and 0. 
Where 1=valid and 0 = not valid.\n", + " '''\n", + " prompt = f'Validate: {ddl}'\n", + "\n", + " messages = [\n", + " {'role': 'system', 'content': system_prompt},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + "\n", + " return generate(messages)\n", + "\n", + "\n", + "# Future work, this will fix any errors in the SQL DDL statement provided it is\n", + "# fixable.\n", + "def sql_fixer(ddl: str):\n", + " pass\n", + "\n", + "\n", + "def parse_ddl(ddl: str):\n", + " system_prompt = f'''\n", + " You are an SQL analyzer, your task is to extract column information to a\n", + " specific JSON structure.\n", + "\n", + " The output must comform to the following JSON schema:\n", + " {Schema.model_json_schema()}\n", + " '''\n", + " prompt = f'Generate schema for: {ddl}'\n", + "\n", + " messages = [\n", + " {'role': 'system', 'content': system_prompt},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + "\n", + " return generate(messages)" + ], + "metadata": { + "id": "aQgY6EK0QPPd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Data Scientist\n", + "\n", + "Just give it a domain and you will be amazed the features will give you." + ], + "metadata": { + "id": "4mgwDQyDQ1wv" + } + }, + { + "cell_type": "code", + "source": [ + "def create_domain_schema(domain: str):\n", + " system_prompt = f'''\n", + " You are an expert Data Scientist tasked to describe features for a dataset\n", + " aspiring data scientists in a chosen domain.\n", + "\n", + " Follow these steps EXACTLY:\n", + " **Define 6–10 features** for the given domain. Include:\n", + " - At least 2 numerical features\n", + " - At least 2 categorical features\n", + " - 1 boolean or binary feature\n", + " - 1 timestamp or date feature\n", + " - Realistic dependencies (e.g., \"if loan_amount > 50000, credit_score should be high\")\n", + "\n", + " Populate your response into the JSON schema below. Strictly out **JSON**\n", + " {Schema.model_json_schema()}\n", + " '''\n", + " prompt = f'Describe the data point. 
Domain: {domain}'\n", + "\n", + " messages = [\n", + " {'role': 'system', 'content': system_prompt},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + "\n", + " return generate(messages)\n", + "\n", + "\n", + "# TODO: Use Gradion Examples to make it easier for the loading of different statements\n", + "sql = '''\n", + "CREATE TABLE users (\n", + " id BIGINT PRIMARY KEY,\n", + " name VARCHAR(100) NOT NULL,\n", + " email TEXT,\n", + " gender ENUM('F', 'M'),\n", + " country VARCHAR(100),\n", + " mobile_number VARCHAR(100),\n", + " created_at TIMESTAMP DEFAULT NOW()\n", + ");\n", + "'''" + ], + "metadata": { + "id": "P36AMvBq8AST" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(f'{model.get_memory_footprint() / 1e9:, .2f} GB')" + ], + "metadata": { + "id": "QuVyHOhjDtSH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Export Functions" + ], + "metadata": { + "id": "tqSpfJGnme7y" + } + }, + { + "cell_type": "code", + "source": [ + "from enum import StrEnum\n", + "\n", + "\n", + "class ExportFormat(StrEnum):\n", + " CSV = 'CSV'\n", + " JSON = 'JSON'\n", + " Excel = 'Excel'\n", + " Parquet = 'Parquet'\n", + " TSV = 'TSV'\n", + " HTML = 'HTML'\n", + " Markdown = 'Markdown'\n", + " SQL = 'SQL'\n", + "\n", + "\n", + "def export_data(df, format_type):\n", + " if df is None or df.empty:\n", + " return None\n", + "\n", + " try:\n", + " if format_type == ExportFormat.CSV:\n", + " output = io.StringIO()\n", + " df.to_csv(output, index=False)\n", + " return output.getvalue()\n", + "\n", + " elif format_type == ExportFormat.JSON:\n", + " return df.to_json(orient='records', indent=2)\n", + "\n", + " elif format_type == ExportFormat.Excel:\n", + " output = io.BytesIO()\n", + " df.to_excel(output, index=False, engine='openpyxl')\n", + " return output.getvalue()\n", + "\n", + " elif format_type == ExportFormat.Parquet:\n", + " output = io.BytesIO()\n", + " df.to_parquet(output, index=False)\n", + " return output.getvalue()\n", + "\n", + " elif format_type == ExportFormat.TSV:\n", + " output = io.StringIO()\n", + " df.to_csv(output, sep='\\t', index=False)\n", + " return output.getvalue()\n", + "\n", + " elif format_type == ExportFormat.HTML:\n", + " return df.to_html(index=False)\n", + "\n", + " elif format_type == ExportFormat.Markdown:\n", + " return df.to_markdown(index=False)\n", + "\n", + " elif format_type == ExportFormat.SQL:\n", + " from sqlalchemy import create_engine\n", + " engine = create_engine('sqlite:///:memory:')\n", + " table = 'users' # TODO: fix this\n", + "\n", + " df.to_sql(table, con=engine, index=False)\n", + " connection = engine.raw_connection()\n", + " sql_statements = list(connection.iterdump())\n", + " sql_output_string = \"\\n\".join(sql_statements)\n", + " connection.close()\n", + "\n", + " return sql_output_string\n", + "\n", + " except Exception as e:\n", + " print(f\"Export error: {str(e)}\")\n", + " return None\n", + "\n", + "\n", + "def prepare_download(df, format_type):\n", + " if df is None:\n", + " return None\n", + "\n", + " content = export_data(df, format_type)\n", + " if content is None:\n", + " return None\n", + "\n", + " extensions = {\n", + " ExportFormat.CSV: '.csv',\n", + " ExportFormat.JSON: '.json',\n", + " ExportFormat.Excel: '.xlsx',\n", + " ExportFormat.Parquet: '.parquet',\n", + " ExportFormat.TSV: '.tsv',\n", + " ExportFormat.HTML: '.html',\n", + " ExportFormat.Markdown: '.md',\n", + " ExportFormat.SQL: '.sql',\n", + " }\n", + "\n", 
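+ "    # The extensions map above pairs each ExportFormat with a file suffix; unknown formats fall back to .txt.\n",
+ "    # Excel and Parquet payloads are bytes, so the tempfile below is opened in 'w+b'; all other formats use text mode 'w'.\n",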
+ " filename = f'generated_data{extensions.get(format_type, \".txt\")}'\n", + "\n", + " is_binary_format = format_type in [ExportFormat.Excel, ExportFormat.Parquet]\n", + " mode = 'w+b' if is_binary_format else 'w'\n", + "\n", + " import tempfile\n", + " with tempfile.NamedTemporaryFile(mode=mode, delete=False, suffix=extensions[format_type]) as tmp:\n", + " tmp.write(content)\n", + " tmp.flush()\n", + " return tmp.name" + ], + "metadata": { + "id": "pAu5OPfUmMSm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Gradio UI" + ], + "metadata": { + "id": "Q0fZsCuso_YZ" + } + }, + { + "cell_type": "code", + "source": [ + "import gradio as gr\n", + "from pydantic import BaseModel, Field\n", + "import json\n", + "import pandas as pd\n", + "import io\n", + "\n", + "DATA_TYPES = ['string', 'integer', 'float', 'boolean', 'date', 'datetime', 'array', 'object']\n", + "\n", + "def generate_from_sql(sql: str, context: str, num_records: int = 10):\n", + " try:\n", + " print(f'SQL: {sql}')\n", + " schema = parse_ddl(sql)\n", + " data = generate_data(schema, context, num_records)\n", + "\n", + " data = json.loads(data)\n", + " df = pd.DataFrame(data)\n", + "\n", + " return schema, df\n", + " except Exception as e:\n", + " return f'Error: {str(e)}', None\n", + "\n", + "\n", + "def generate_from_data_scientist(domain: str, context: str, num_records: int = 10):\n", + " try:\n", + " print(f'Domain: {domain}')\n", + " schema = create_domain_schema(domain)\n", + " print(schema)\n", + " data = generate_data(schema, context, num_records)\n", + " data = json.loads(data)\n", + " df = pd.DataFrame(data)\n", + "\n", + " return schema, df\n", + " except Exception as e:\n", + " return f'Error: {str(e)}', None\n", + "\n", + "\n", + "def generate_from_dynamic_fields(schema_name, context: str, num_fields, num_records: int, *field_values):\n", + " try:\n", + " fields = []\n", + " for i in range(num_fields):\n", + " idx = i * 4\n", + " if idx + 3 < len(field_values):\n", + " name = field_values[idx]\n", + " dtype = field_values[idx + 1]\n", + " nullable = field_values[idx + 2]\n", + " desc = field_values[idx + 3]\n", + "\n", + " if name and dtype:\n", + " fields.append(FieldDescriptor(\n", + " name=name,\n", + " data_type=dtype,\n", + " nullable=nullable if nullable is not None else False,\n", + " description=desc if desc else ''\n", + " ))\n", + "\n", + " if not schema_name:\n", + " return 'Error: Schema name is required', None\n", + "\n", + " if not fields:\n", + " return 'Error: At least one field is required', None\n", + "\n", + " schema = Schema(name=schema_name, fields=fields)\n", + " data = generate_data(schema.model_dump(), context , num_records)\n", + " data = json.loads(data)\n", + " df = pd.DataFrame(data)\n", + "\n", + "\n", + " return json.dumps(schema.model_dump(), indent=2), df\n", + "\n", + " except Exception as e:\n", + " return f'Error: {str(e)}', None\n", + "\n", + "\n", + "\n", + "title='✨ Coherent Data Generator'\n", + "\n", + "with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:\n", + " gr.Markdown(f'# {title}')\n", + " gr.Markdown('Embrass the Coherent Data wins 🏆!')\n", + "\n", + " df_state = gr.State(value=None)\n", + "\n", + " with gr.Row():\n", + " num_records_input = gr.Number(\n", + " label='Number of Records to Generate',\n", + " value=10,\n", + " minimum=1,\n", + " maximum=10000,\n", + " step=1,\n", + " precision=0\n", + " )\n", + "\n", + " context_input = gr.Textbox(\n", + " label='Context',\n", + " placeholder='70% Ghana and 
30% Nigeria data. Start ID generation from 200',\n", + " lines=1\n", + " )\n", + "\n", + " with gr.Tabs() as tabs:\n", + " with gr.Tab('Manual Entry', id=0):\n", + " schema_name_input = gr.Textbox(label='Schema Name', placeholder='Enter schema name')\n", + "\n", + " gr.Markdown('### Fields')\n", + "\n", + " num_fields_state = gr.State(3)\n", + "\n", + " with gr.Row():\n", + " num_fields_slider = gr.Slider(\n", + " minimum=1,\n", + " maximum=20,\n", + " value=3,\n", + " step=1,\n", + " label='Number of Fields',\n", + " interactive=True\n", + " )\n", + "\n", + " gr.HTML('''\n", + "
\n", + "
Field Name
\n", + "
Data Type
\n", + "
Nullable
\n", + "
Description
\n", + "
\n", + " ''')\n", + "\n", + " field_components = []\n", + " row_components = []\n", + "\n", + " for i in range(20):\n", + " with gr.Row(visible=(i < 3)) as row:\n", + " field_name = gr.Textbox(label='', container=False, scale=2)\n", + " data_type = gr.Dropdown(choices=DATA_TYPES, value='string', label='', container=False, scale=2)\n", + " nullable = gr.Checkbox(label='', container=False, scale=1)\n", + " description = gr.Textbox(label='', container=False, scale=3)\n", + "\n", + " row_components.append(row)\n", + " field_components.extend([field_name, data_type, nullable, description])\n", + "\n", + " submit_btn = gr.Button('Generate', variant='primary')\n", + "\n", + " num_fields_slider.change(\n", + " fn=lambda x: [gr.update(visible=(i < x)) for i in range(20)],\n", + " inputs=[num_fields_slider],\n", + " outputs=row_components\n", + " )\n", + "\n", + "\n", + " with gr.Tab('SQL', id=1):\n", + " gr.Markdown('### Parse SQL DDL')\n", + " ddl_input = gr.Code(\n", + " value=sql,\n", + " label='SQL DDL Statement',\n", + " language='sql',\n", + " lines=10\n", + " )\n", + " ddl_btn = gr.Button('Generate', variant='primary')\n", + "\n", + "\n", + " with gr.Tab('>_ Prompt', id=2):\n", + " gr.Markdown('### You are on your own here, so be creative 💡')\n", + " prompt_input = gr.Textbox(\n", + " label='Prompt',\n", + " placeholder='Type your prompt',\n", + " lines=10\n", + " )\n", + " prompt_btn = gr.Button('Generate', variant='primary')\n", + "\n", + " with gr.Tab('Data Scientist 🎩', id=3):\n", + " gr.Markdown('### You are on your own here, so be creative 💡')\n", + " domain_input = gr.Dropdown(\n", + " label='Domain',\n", + " choices=['E-commerce Customers', 'Hospital Patients', 'Loan Applications'],\n", + " allow_custom_value=True\n", + " )\n", + "\n", + " data_scientist_generate_btn = gr.Button('Generate', variant='primary')\n", + "\n", + "\n", + " with gr.Accordion('Generated Schema', open=False):\n", + " output = gr.Code(label='Schema (JSON)', language='json')\n", + "\n", + " gr.Markdown('## Generated Data')\n", + " dataframe_output = gr.Dataframe(\n", + " label='',\n", + " interactive=False,\n", + " wrap=True\n", + " )\n", + "\n", + " gr.Markdown('### Export Data')\n", + " with gr.Row():\n", + " format_dropdown = gr.Dropdown(\n", + " choices=[format.value for format in ExportFormat],\n", + " value=ExportFormat.CSV,\n", + " label='Export Format',\n", + " scale=2\n", + " )\n", + " download_btn = gr.Button('Download', variant='secondary', scale=1)\n", + "\n", + " download_file = gr.File(label='Download File', visible=True)\n", + "\n", + "\n", + " def _handle_result(result):\n", + " if isinstance(result, tuple) and len(result) == 2:\n", + " return result[0], result[1], result[1]\n", + " return result[0], result[1], None\n", + "\n", + "\n", + " def update_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values):\n", + " result = generate_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values)\n", + " return _handle_result(result)\n", + "\n", + "\n", + " def update_from_sql(sql: str, context, num_records: int):\n", + " result = generate_from_sql(sql, context, num_records)\n", + " return _handle_result(result)\n", + "\n", + "\n", + " def update_from_data_scientist(domain: str, context, num_records: int):\n", + " result = generate_from_data_scientist(domain, context, num_records)\n", + " return _handle_result(result)\n", + "\n", + "\n", + " submit_btn.click(\n", + " fn=update_from_dynamic_fields,\n", + " inputs=[schema_name_input, context_input, 
num_fields_slider, num_records_input] + field_components,\n", + " outputs=[output, dataframe_output, df_state]\n", + " )\n", + "\n", + " ddl_btn.click(\n", + " fn=update_from_sql,\n", + " inputs=[ddl_input, context_input, num_records_input],\n", + " outputs=[output, dataframe_output, df_state]\n", + " )\n", + "\n", + " data_scientist_generate_btn.click(\n", + " fn=update_from_data_scientist,\n", + " inputs=[domain_input, context_input, num_records_input],\n", + " outputs=[output, dataframe_output, df_state]\n", + " )\n", + "\n", + "\n", + " download_btn.click(\n", + " fn=prepare_download,\n", + " inputs=[df_state, format_dropdown],\n", + " outputs=[download_file]\n", + " )\n", + "\n", + "\n", + "ui.launch(debug=True)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TJYUWecybDpP", + "outputId": "e82d0a13-3ca3-4a01-d45c-78fc94ade9bc" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Keyboard interruption in main thread... closing server.\n", + "Killing tunnel 127.0.0.1:7860 <> https://5954eb89d994d7a5ee.gradio.live\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [] + }, + "metadata": {}, + "execution_count": 10 + } + ] + } + ] +} \ No newline at end of file