{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "KbMea_UrO3Ke" }, "source": [ "# ✨ Coherent Data Generator\n", "\n", "## In real life, data has meaning, relationships, etc., and this is where this tool shines.\n", "\n", "Dependencies between fields are detected, and coherent data is generated.\n", "Example:\n", "When asked to generate data with **Ghana** cited as the context, fields like `name`, `food`, etc., will be Ghanaian. Fields such as phone number will have the appropriate prefix of `+233`, etc.\n", "\n", "This is better than Faker, which generates each field independently.\n", "\n", "## Steps\n", "Schema -> Generate Data\n", "\n", "Schema Sources:\n", "- Use the guided schema builder\n", "- Bring your own schema as an SQL Data Definition Language (DDL) statement\n", "- Write your own prompt\n", "- Provide a domain to an old hat that defines features for a dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cN8z-QNlFtYc" }, "outputs": [], "source": [ "import json\n", "\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", "import torch\n", "import pandas as pd\n", "\n", "from pydantic import BaseModel, Field\n", "from IPython.display import display, Markdown" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DOBBN3P2GD2O" }, "outputs": [], "source": [ "model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n", "\n", "device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'\n", "print(f'Device: {device}')\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " dtype=\"auto\",\n", " device_map=\"auto\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "HSUebXa1O3MM" }, "source": [ "## Schema Definitions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5LNM76OQjAw6" }, "outputs": [], "source": [ "# This is for future use where errors in SQL DDL 
statements can be fixed if the\n", "# user specifies that from the UI\n", "class SQLValidationResult(BaseModel):\n", " is_valid: bool\n", " is_fixable: bool\n", " reason: str = Field(default='', description='validation failure reason')\n", "\n", "\n", "class FieldDescriptor(BaseModel):\n", " name: str = Field(..., description='Name of the field')\n", " data_type: str = Field(..., description='Type of the field')\n", " nullable: bool\n", " description: str = Field(..., description='Description of the field')\n", "\n", "\n", "class Schema(BaseModel):\n", " name: str = Field(..., description='Name of the schema')\n", " fields: list[FieldDescriptor] = Field(..., description='List of fields in the schema')" ] }, { "cell_type": "markdown", "metadata": { "id": "6QjitfTBPa1E" }, "source": [ "## LLM Interactions" ] }, { "cell_type": "markdown", "metadata": { "id": "dXiRHok7Peir" }, "source": [ "### Generate Content from LLM" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "daTUVG8_PmvM" }, "outputs": [], "source": [ "def generate(messages: list[dict[str, str]], temperature: float = 0.1) -> str:\n", " text = tokenizer.apply_chat_template(\n", " messages,\n", " tokenize=False,\n", " add_generation_prompt=True,\n", " )\n", " model_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", "\n", " generated_ids = model.generate(\n", " **model_inputs,\n", " max_new_tokens=16384,\n", " temperature=temperature\n", " )\n", "\n", " output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", " content = tokenizer.decode(output_ids, skip_special_tokens=True)\n", "\n", " return content" ] }, { "cell_type": "markdown", "metadata": { "id": "sBHJKn8qQhM5" }, "source": [ "### Generate Data Given A Valid Schema" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fla8UQf4Qm5l" }, "outputs": [], "source": [ "def generate_data(schema: str, context: str = '', num_records: int = 5):\n", " system_prompt = f'''\n", " You are a 
synthetic data generator, you generate data based on the given schema\n", " into a specific JSON structure.\n", " When a context is provided, intelligently use that to drive the field generation.\n", "\n", " Example:\n", " If Africa is given as the context, fields like name, first_name, last_name, etc.\n", " that can be derived from Africa will be generated.\n", "\n", " If no context is provided, generate data randomly.\n", "\n", " Output an array of JSON objects.\n", " '''\n", "\n", " prompt = f'''\n", " Generate {num_records} records:\n", "\n", " Schema:\n", " {schema}\n", "\n", " Context:\n", " {context}\n", " '''\n", "\n", " messages = [\n", " {'role': 'system', 'content': system_prompt},\n", " {\"role\": \"user\", \"content\": prompt}\n", " ]\n", "\n", " return generate(messages)" ] }, { "cell_type": "markdown", "metadata": { "id": "izrClU6VPsZp" }, "source": [ "### SQL" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aQgY6EK0QPPd" }, "outputs": [], "source": [ "def sql_validator(ddl: str):\n", " system_prompt = '''\n", " You are an SQL validator; your task is to determine whether the given SQL is valid.\n", " ONLY return a binary response of 1 or 0. 
Where 1 = valid and 0 = invalid.\n", " '''\n", " prompt = f'Validate: {ddl}'\n", "\n", " messages = [\n", " {'role': 'system', 'content': system_prompt},\n", " {\"role\": \"user\", \"content\": prompt}\n", " ]\n", "\n", " return generate(messages)\n", "\n", "\n", "# Future work: this will fix any errors in the SQL DDL statement provided it is\n", "# fixable.\n", "def sql_fixer(ddl: str):\n", " pass\n", "\n", "\n", "def parse_ddl(ddl: str):\n", " system_prompt = f'''\n", " You are an SQL analyzer; your task is to extract column information into a\n", " specific JSON structure.\n", "\n", " The output must conform to the following JSON schema:\n", " {Schema.model_json_schema()}\n", " '''\n", " prompt = f'Generate schema for: {ddl}'\n", "\n", " messages = [\n", " {'role': 'system', 'content': system_prompt},\n", " {\"role\": \"user\", \"content\": prompt}\n", " ]\n", "\n", " return generate(messages)" ] }, { "cell_type": "markdown", "metadata": { "id": "4mgwDQyDQ1wv" }, "source": [ "### Data Scientist\n", "\n", "Just give it a domain and you will be amazed at the features it will give you." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P36AMvBq8AST" }, "outputs": [], "source": [ "def create_domain_schema(domain: str):\n", " system_prompt = f'''\n", " You are an expert Data Scientist tasked to describe features for a dataset for\n", " aspiring data scientists in a chosen domain.\n", "\n", " Follow these steps EXACTLY:\n", " **Define 6–10 features** for the given domain. Include:\n", " - At least 2 numerical features\n", " - At least 2 categorical features\n", " - 1 boolean or binary feature\n", " - 1 timestamp or date feature\n", " - Realistic dependencies (e.g., \"if loan_amount > 50000, credit_score should be high\")\n", "\n", " Populate your response into the JSON schema below. Strictly output **JSON**.\n", " {Schema.model_json_schema()}\n", " '''\n", " prompt = f'Describe the data point. 
Domain: {domain}'\n", "\n", " messages = [\n", " {'role': 'system', 'content': system_prompt},\n", " {\"role\": \"user\", \"content\": prompt}\n", " ]\n", "\n", " return generate(messages)\n", "\n", "\n", "# TODO: Use Gradio Examples to make it easier to load different statements\n", "sql = '''\n", "CREATE TABLE users (\n", " id BIGINT PRIMARY KEY,\n", " name VARCHAR(100) NOT NULL,\n", " email TEXT,\n", " gender ENUM('F', 'M'),\n", " country VARCHAR(100),\n", " mobile_number VARCHAR(100),\n", " created_at TIMESTAMP DEFAULT NOW()\n", ");\n", "'''" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QuVyHOhjDtSH" }, "outputs": [], "source": [ "print(f'{model.get_memory_footprint() / 1e9:,.2f} GB')" ] }, { "cell_type": "markdown", "metadata": { "id": "tqSpfJGnme7y" }, "source": [ "## Export Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pAu5OPfUmMSm" }, "outputs": [], "source": [ "import io\n", "from enum import StrEnum\n", "\n", "\n", "class ExportFormat(StrEnum):\n", " CSV = 'CSV'\n", " JSON = 'JSON'\n", " Excel = 'Excel'\n", " Parquet = 'Parquet'\n", " TSV = 'TSV'\n", " HTML = 'HTML'\n", " Markdown = 'Markdown'\n", " SQL = 'SQL'\n", "\n", "\n", "def export_data(df, format_type):\n", " if df is None or df.empty:\n", " return None\n", "\n", " try:\n", " if format_type == ExportFormat.CSV:\n", " output = io.StringIO()\n", " df.to_csv(output, index=False)\n", " return output.getvalue()\n", "\n", " elif format_type == ExportFormat.JSON:\n", " return df.to_json(orient='records', indent=2)\n", "\n", " elif format_type == ExportFormat.Excel:\n", " output = io.BytesIO()\n", " df.to_excel(output, index=False, engine='openpyxl')\n", " return output.getvalue()\n", "\n", " elif format_type == ExportFormat.Parquet:\n", " output = io.BytesIO()\n", " df.to_parquet(output, index=False)\n", " return output.getvalue()\n", "\n", " elif format_type == ExportFormat.TSV:\n", " output = io.StringIO()\n", " df.to_csv(output, 
sep='\\t', index=False)\n", " return output.getvalue()\n", "\n", " elif format_type == ExportFormat.HTML:\n", " return df.to_html(index=False)\n", "\n", " elif format_type == ExportFormat.Markdown:\n", " return df.to_markdown(index=False)\n", "\n", " elif format_type == ExportFormat.SQL:\n", " from sqlalchemy import create_engine\n", " engine = create_engine('sqlite:///:memory:')\n", " table = 'users' # TODO: fix this\n", "\n", " df.to_sql(table, con=engine, index=False)\n", " connection = engine.raw_connection()\n", " sql_statements = list(connection.iterdump())\n", " sql_output_string = \"\\n\".join(sql_statements)\n", " connection.close()\n", "\n", " return sql_output_string\n", "\n", " except Exception as e:\n", " print(f\"Export error: {str(e)}\")\n", " return None\n", "\n", "\n", "def prepare_download(df, format_type):\n", " if df is None:\n", " return None\n", "\n", " content = export_data(df, format_type)\n", " if content is None:\n", " return None\n", "\n", " extensions = {\n", " ExportFormat.CSV: '.csv',\n", " ExportFormat.JSON: '.json',\n", " ExportFormat.Excel: '.xlsx',\n", " ExportFormat.Parquet: '.parquet',\n", " ExportFormat.TSV: '.tsv',\n", " ExportFormat.HTML: '.html',\n", " ExportFormat.Markdown: '.md',\n", " ExportFormat.SQL: '.sql',\n", " }\n", "\n", " filename = f'generated_data{extensions.get(format_type, \".txt\")}'\n", "\n", " is_binary_format = format_type in [ExportFormat.Excel, ExportFormat.Parquet]\n", " mode = 'w+b' if is_binary_format else 'w'\n", "\n", " import tempfile\n", " with tempfile.NamedTemporaryFile(mode=mode, delete=False, suffix=extensions[format_type]) as tmp:\n", " tmp.write(content)\n", " tmp.flush()\n", " return tmp.name" ] }, { "cell_type": "markdown", "metadata": { "id": "Q0fZsCuso_YZ" }, "source": [ "## Gradio UI" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TJYUWecybDpP", "outputId": "e82d0a13-3ca3-4a01-d45c-78fc94ade9bc" }, "outputs": 
[], "source": [ "import gradio as gr\n", "from pydantic import BaseModel, Field\n", "import json\n", "import pandas as pd\n", "import io\n", "\n", "DATA_TYPES = ['string', 'integer', 'float', 'boolean', 'date', 'datetime', 'array', 'object']\n", "\n", "def generate_from_sql(sql: str, context: str, num_records: int = 10):\n", " try:\n", " print(f'SQL: {sql}')\n", " schema = parse_ddl(sql)\n", " data = generate_data(schema, context, num_records)\n", "\n", " data = json.loads(data)\n", " df = pd.DataFrame(data)\n", "\n", " return schema, df\n", " except Exception as e:\n", " return f'Error: {str(e)}', None\n", "\n", "\n", "def generate_from_data_scientist(domain: str, context: str, num_records: int = 10):\n", " try:\n", " print(f'Domain: {domain}')\n", " schema = create_domain_schema(domain)\n", " print(schema)\n", " data = generate_data(schema, context, num_records)\n", " data = json.loads(data)\n", " df = pd.DataFrame(data)\n", "\n", " return schema, df\n", " except Exception as e:\n", " return f'Error: {str(e)}', None\n", "\n", "\n", "def generate_from_dynamic_fields(schema_name, context: str, num_fields, num_records: int, *field_values):\n", " try:\n", " fields = []\n", " for i in range(num_fields):\n", " idx = i * 4\n", " if idx + 3 < len(field_values):\n", " name = field_values[idx]\n", " dtype = field_values[idx + 1]\n", " nullable = field_values[idx + 2]\n", " desc = field_values[idx + 3]\n", "\n", " if name and dtype:\n", " fields.append(FieldDescriptor(\n", " name=name,\n", " data_type=dtype,\n", " nullable=nullable if nullable is not None else False,\n", " description=desc if desc else ''\n", " ))\n", "\n", " if not schema_name:\n", " return 'Error: Schema name is required', None\n", "\n", " if not fields:\n", " return 'Error: At least one field is required', None\n", "\n", " schema = Schema(name=schema_name, fields=fields)\n", " data = generate_data(schema.model_dump(), context, num_records)\n", " data = json.loads(data)\n", " df = pd.DataFrame(data)\n", "\n", "\n", " return json.dumps(schema.model_dump(), indent=2), df\n", "\n", " except Exception as e:\n", " return f'Error: {str(e)}', None\n", "\n", "\n", "\n", "title='✨ Coherent Data Generator'\n", "\n", "with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:\n", " gr.Markdown(f'# {title}')\n", " gr.Markdown('Embrace the Coherent Data wins 🏆!')\n", "\n", " df_state = gr.State(value=None)\n", "\n", " with gr.Row():\n", " num_records_input = gr.Number(\n", " label='Number of Records to Generate',\n", " value=10,\n", " minimum=1,\n", " maximum=10000,\n", " step=1,\n", " precision=0\n", " )\n", "\n", " context_input = gr.Textbox(\n", " label='Context',\n", " placeholder='70% Ghana and 30% Nigeria data. Start ID generation from 200',\n", " lines=1\n", " )\n", "\n", " with gr.Tabs() as tabs:\n", " with gr.Tab('Manual Entry', id=0):\n", " schema_name_input = gr.Textbox(label='Schema Name', placeholder='Enter schema name')\n", "\n", " gr.Markdown('### Fields')\n", "\n", " num_fields_state = gr.State(3)\n", "\n", " with gr.Row():\n", " num_fields_slider = gr.Slider(\n", " minimum=1,\n", " maximum=20,\n", " value=3,\n", " step=1,\n", " label='Number of Fields',\n", " interactive=True\n", " )\n", "\n", " gr.HTML('''\n", "
<div style='display: flex; font-weight: bold;'>\n", " <div style='flex: 2;'>Field Name</div>\n", " <div style='flex: 2;'>Data Type</div>\n", " <div style='flex: 1;'>Nullable</div>\n", " <div style='flex: 3;'>Description</div>\n", "</div>
\n", " ''')\n", "\n", " field_components = []\n", " row_components = []\n", "\n", " for i in range(20):\n", " with gr.Row(visible=(i < 3)) as row:\n", " field_name = gr.Textbox(label='', container=False, scale=2)\n", " data_type = gr.Dropdown(choices=DATA_TYPES, value='string', label='', container=False, scale=2)\n", " nullable = gr.Checkbox(label='', container=False, scale=1)\n", " description = gr.Textbox(label='', container=False, scale=3)\n", "\n", " row_components.append(row)\n", " field_components.extend([field_name, data_type, nullable, description])\n", "\n", " submit_btn = gr.Button('Generate', variant='primary')\n", "\n", " num_fields_slider.change(\n", " fn=lambda x: [gr.update(visible=(i < x)) for i in range(20)],\n", " inputs=[num_fields_slider],\n", " outputs=row_components\n", " )\n", "\n", "\n", " with gr.Tab('SQL', id=1):\n", " gr.Markdown('### Parse SQL DDL')\n", " ddl_input = gr.Code(\n", " value=sql,\n", " label='SQL DDL Statement',\n", " language='sql',\n", " lines=10\n", " )\n", " ddl_btn = gr.Button('Generate', variant='primary')\n", "\n", "\n", " with gr.Tab('>_ Prompt', id=2):\n", " gr.Markdown('### You are on your own here, so be creative 💡')\n", " prompt_input = gr.Textbox(\n", " label='Prompt',\n", " placeholder='Type your prompt',\n", " lines=10\n", " )\n", " prompt_btn = gr.Button('Generate', variant='primary')\n", "\n", " with gr.Tab('Data Scientist 🎩', id=3):\n", " gr.Markdown('### Give it a domain and be amazed at the features it will define 💡')\n", " domain_input = gr.Dropdown(\n", " label='Domain',\n", " choices=['E-commerce Customers', 'Hospital Patients', 'Loan Applications'],\n", " allow_custom_value=True\n", " )\n", "\n", " data_scientist_generate_btn = gr.Button('Generate', variant='primary')\n", "\n", "\n", " with gr.Accordion('Generated Schema', open=False):\n", " output = gr.Code(label='Schema (JSON)', language='json')\n", "\n", " gr.Markdown('## Generated Data')\n", " dataframe_output = gr.Dataframe(\n", " label='',\n", " 
interactive=False,\n", " wrap=True\n", " )\n", "\n", " gr.Markdown('### Export Data')\n", " with gr.Row():\n", " format_dropdown = gr.Dropdown(\n", " choices=[format.value for format in ExportFormat],\n", " value=ExportFormat.CSV,\n", " label='Export Format',\n", " scale=2\n", " )\n", " download_btn = gr.Button('Download', variant='secondary', scale=1)\n", "\n", " download_file = gr.File(label='Download File', visible=True)\n", "\n", "\n", " def _handle_result(result):\n", " if isinstance(result, tuple) and len(result) == 2:\n", " return result[0], result[1], result[1]\n", " return result[0], result[1], None\n", "\n", "\n", " def update_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values):\n", " result = generate_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values)\n", " return _handle_result(result)\n", "\n", "\n", " def update_from_sql(sql: str, context, num_records: int):\n", " result = generate_from_sql(sql, context, num_records)\n", " return _handle_result(result)\n", "\n", "\n", " def update_from_data_scientist(domain: str, context, num_records: int):\n", " result = generate_from_data_scientist(domain, context, num_records)\n", " return _handle_result(result)\n", "\n", "\n", " submit_btn.click(\n", " fn=update_from_dynamic_fields,\n", " inputs=[schema_name_input, context_input, num_fields_slider, num_records_input] + field_components,\n", " outputs=[output, dataframe_output, df_state]\n", " )\n", "\n", " ddl_btn.click(\n", " fn=update_from_sql,\n", " inputs=[ddl_input, context_input, num_records_input],\n", " outputs=[output, dataframe_output, df_state]\n", " )\n", "\n", " data_scientist_generate_btn.click(\n", " fn=update_from_data_scientist,\n", " inputs=[domain_input, context_input, num_records_input],\n", " outputs=[output, dataframe_output, df_state]\n", " )\n", "\n", "\n", " download_btn.click(\n", " fn=prepare_download,\n", " inputs=[df_state, format_dropdown],\n", " outputs=[download_file]\n", 
" )\n", "\n", "\n", "ui.launch(debug=True)\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "tqSpfJGnme7y" ], "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }