Files
LLM_Engineering_OLD/week3/community-contributions/ranskills-week3-coherent-data-generator.ipynb
2025-10-27 11:42:28 +00:00

735 lines
26 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "KbMea_UrO3Ke"
},
"source": [
"# ✨ Coherent Data Generator\n",
"\n",
"## In real life, data has meaning, relationships, etc., and this is where this tool shines.\n",
"\n",
"Dependencies between fields are detected, and coherent data is generated.\n",
"Example:\n",
"When asked to generate data with **Ghana** cited as the context, fields like `name`, `food`, etc., will be Ghanaian. Fields such as phone number will have the appropriate prefix of `+233`, etc.\n",
"\n",
"This is better than Faker.\n",
"\n",
"## Steps\n",
"Schema -> Generate Data\n",
"\n",
"Schema Sources: \n",
"- Use the guided schema builder\n",
"- Bring your own schema from an SQL Data Definition Language (DDL)\n",
"- Prompting\n",
"- Providing a domain to an old hat to define features for a dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cN8z-QNlFtYc"
},
"outputs": [],
"source": [
"import json\n",
"\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
"import torch\n",
"import pandas as pd\n",
"\n",
"from pydantic import BaseModel, Field\n",
"from IPython.display import display, Markdown"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DOBBN3P2GD2O"
},
"outputs": [],
"source": [
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
"\n",
"device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'\n",
"print(f'Device: {device}')\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" dtype=\"auto\",\n",
" device_map=\"auto\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSUebXa1O3MM"
},
"source": [
"## Schema Definitions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5LNM76OQjAw6"
},
"outputs": [],
"source": [
"# This is for future use where errors in SQL DDL statements can be fixed if the\n",
"# user specifies that from the UI\n",
"class SQLValidationResult(BaseModel):\n",
" is_valid: bool\n",
" is_fixable: bool\n",
" reason: str = Field(default='', description='validation failure reason')\n",
"\n",
"\n",
"class FieldDescriptor(BaseModel):\n",
" name: str = Field(..., description='Name of the field')\n",
" data_type: str = Field(..., description='Type of the field')\n",
" nullable: bool\n",
" description: str = Field(..., description='Description of the field')\n",
"\n",
"\n",
"class Schema(BaseModel):\n",
" name: str = Field(..., description='Name of the schema')\n",
" fields: list[FieldDescriptor] = Field(..., description='List of fields in the schema')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6QjitfTBPa1E"
},
"source": [
"## LLM Interactions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dXiRHok7Peir"
},
"source": [
"### Generate Content from LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "daTUVG8_PmvM"
},
"outputs": [],
"source": [
"def generate(messages: list[dict[str, str]], temperature: float = 0.1) -> str:\n",
" text = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )\n",
" model_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n",
"\n",
" generated_ids = model.generate(\n",
" **model_inputs,\n",
" max_new_tokens=16384,\n",
" temperature=temperature\n",
" )\n",
"\n",
" output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n",
" content = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
"\n",
" return content"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sBHJKn8qQhM5"
},
"source": [
"### Generate Data Given A Valid Schema"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fla8UQf4Qm5l"
},
"outputs": [],
"source": [
"def generate_data(schema: str, context: str = '', num_records: int = 5):\n",
" system_prompt = f'''\n",
" You are a synthetic data generator; you generate data based on the given schema\n",
" in a specific JSON structure.\n",
" When a context is provided, intelligently use that to drive the field generation.\n",
"\n",
" Example:\n",
" If Africa is given as the context, fields like name, first_name, last_name, etc.\n",
" that can be derived from Africa will be generated.\n",
"\n",
" If no context is provided, generate data randomly.\n",
"\n",
" Output an array of JSON objects.\n",
" '''\n",
"\n",
" prompt = f'''\n",
" Generate {num_records} records:\n",
"\n",
" Schema:\n",
" {schema}\n",
"\n",
" Context:\n",
" {context}\n",
" '''\n",
"\n",
" messages = [\n",
" {'role': 'system', 'content': system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
"\n",
" return generate(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "izrClU6VPsZp"
},
"source": [
"### SQL"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aQgY6EK0QPPd"
},
"outputs": [],
"source": [
"def sql_validator(ddl: str):\n",
" system_prompt = '''\n",
" You are an SQL validator, your task is to validate if the given SQL is valid or not.\n",
" ONLY return a binary response of 1 and 0. Where 1=valid and 0 = not valid.\n",
" '''\n",
" prompt = f'Validate: {ddl}'\n",
"\n",
" messages = [\n",
" {'role': 'system', 'content': system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
"\n",
" return generate(messages)\n",
"\n",
"\n",
"# Future work, this will fix any errors in the SQL DDL statement provided it is\n",
"# fixable.\n",
"def sql_fixer(ddl: str):\n",
" pass\n",
"\n",
"\n",
"def parse_ddl(ddl: str):\n",
" system_prompt = f'''\n",
" You are an SQL analyzer, your task is to extract column information to a\n",
" specific JSON structure.\n",
"\n",
" The output must conform to the following JSON schema:\n",
" {Schema.model_json_schema()}\n",
" '''\n",
" prompt = f'Generate schema for: {ddl}'\n",
"\n",
" messages = [\n",
" {'role': 'system', 'content': system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
"\n",
" return generate(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4mgwDQyDQ1wv"
},
"source": [
"### Data Scientist\n",
"\n",
"Just give it a domain and you will be amazed at the features it will give you."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P36AMvBq8AST"
},
"outputs": [],
"source": [
"def create_domain_schema(domain: str):\n",
" system_prompt = f'''\n",
" You are an expert Data Scientist tasked to describe features for a dataset\n",
" for aspiring data scientists in a chosen domain.\n",
"\n",
" Follow these steps EXACTLY:\n",
" **Define 6-10 features** for the given domain. Include:\n",
" - At least 2 numerical features\n",
" - At least 2 categorical features\n",
" - 1 boolean or binary feature\n",
" - 1 timestamp or date feature\n",
" - Realistic dependencies (e.g., \"if loan_amount > 50000, credit_score should be high\")\n",
"\n",
" Populate your response into the JSON schema below. Strictly output **JSON**\n",
" {Schema.model_json_schema()}\n",
" '''\n",
" prompt = f'Describe the data point. Domain: {domain}'\n",
"\n",
" messages = [\n",
" {'role': 'system', 'content': system_prompt},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
"\n",
" return generate(messages)\n",
"\n",
"\n",
"# TODO: Use Gradio Examples to make it easier for the loading of different statements\n",
"sql = '''\n",
"CREATE TABLE users (\n",
" id BIGINT PRIMARY KEY,\n",
" name VARCHAR(100) NOT NULL,\n",
" email TEXT,\n",
" gender ENUM('F', 'M'),\n",
" country VARCHAR(100),\n",
" mobile_number VARCHAR(100),\n",
" created_at TIMESTAMP DEFAULT NOW()\n",
");\n",
"'''"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuVyHOhjDtSH"
},
"outputs": [],
"source": [
"print(f'{model.get_memory_footprint() / 1e9:,.2f} GB')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqSpfJGnme7y"
},
"source": [
"## Export Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pAu5OPfUmMSm"
},
"outputs": [],
"source": [
"from enum import StrEnum\n",
"\n",
"\n",
"class ExportFormat(StrEnum):\n",
" CSV = 'CSV'\n",
" JSON = 'JSON'\n",
" Excel = 'Excel'\n",
" Parquet = 'Parquet'\n",
" TSV = 'TSV'\n",
" HTML = 'HTML'\n",
" Markdown = 'Markdown'\n",
" SQL = 'SQL'\n",
"\n",
"\n",
"def export_data(df, format_type):\n",
" if df is None or df.empty:\n",
" return None\n",
"\n",
" try:\n",
" if format_type == ExportFormat.CSV:\n",
" output = io.StringIO()\n",
" df.to_csv(output, index=False)\n",
" return output.getvalue()\n",
"\n",
" elif format_type == ExportFormat.JSON:\n",
" return df.to_json(orient='records', indent=2)\n",
"\n",
" elif format_type == ExportFormat.Excel:\n",
" output = io.BytesIO()\n",
" df.to_excel(output, index=False, engine='openpyxl')\n",
" return output.getvalue()\n",
"\n",
" elif format_type == ExportFormat.Parquet:\n",
" output = io.BytesIO()\n",
" df.to_parquet(output, index=False)\n",
" return output.getvalue()\n",
"\n",
" elif format_type == ExportFormat.TSV:\n",
" output = io.StringIO()\n",
" df.to_csv(output, sep='\\t', index=False)\n",
" return output.getvalue()\n",
"\n",
" elif format_type == ExportFormat.HTML:\n",
" return df.to_html(index=False)\n",
"\n",
" elif format_type == ExportFormat.Markdown:\n",
" return df.to_markdown(index=False)\n",
"\n",
" elif format_type == ExportFormat.SQL:\n",
" from sqlalchemy import create_engine\n",
" engine = create_engine('sqlite:///:memory:')\n",
" table = 'users' # TODO: fix this\n",
"\n",
" df.to_sql(table, con=engine, index=False)\n",
" connection = engine.raw_connection()\n",
" sql_statements = list(connection.iterdump())\n",
" sql_output_string = \"\\n\".join(sql_statements)\n",
" connection.close()\n",
"\n",
" return sql_output_string\n",
"\n",
" except Exception as e:\n",
" print(f\"Export error: {str(e)}\")\n",
" return None\n",
"\n",
"\n",
"def prepare_download(df, format_type):\n",
" if df is None:\n",
" return None\n",
"\n",
" content = export_data(df, format_type)\n",
" if content is None:\n",
" return None\n",
"\n",
" extensions = {\n",
" ExportFormat.CSV: '.csv',\n",
" ExportFormat.JSON: '.json',\n",
" ExportFormat.Excel: '.xlsx',\n",
" ExportFormat.Parquet: '.parquet',\n",
" ExportFormat.TSV: '.tsv',\n",
" ExportFormat.HTML: '.html',\n",
" ExportFormat.Markdown: '.md',\n",
" ExportFormat.SQL: '.sql',\n",
" }\n",
"\n",
" filename = f'generated_data{extensions.get(format_type, \".txt\")}'\n",
"\n",
" is_binary_format = format_type in [ExportFormat.Excel, ExportFormat.Parquet]\n",
" mode = 'w+b' if is_binary_format else 'w'\n",
"\n",
" import tempfile\n",
" with tempfile.NamedTemporaryFile(mode=mode, delete=False, suffix=extensions[format_type]) as tmp:\n",
" tmp.write(content)\n",
" tmp.flush()\n",
" return tmp.name"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q0fZsCuso_YZ"
},
"source": [
"## Gradio UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TJYUWecybDpP",
"outputId": "e82d0a13-3ca3-4a01-d45c-78fc94ade9bc"
},
"outputs": [],
"source": [
"import gradio as gr\n",
"from pydantic import BaseModel, Field\n",
"import json\n",
"import pandas as pd\n",
"import io\n",
"\n",
"DATA_TYPES = ['string', 'integer', 'float', 'boolean', 'date', 'datetime', 'array', 'object']\n",
"\n",
"def generate_from_sql(sql: str, context: str, num_records: int = 10):\n",
" try:\n",
" print(f'SQL: {sql}')\n",
" schema = parse_ddl(sql)\n",
" data = generate_data(schema, context, num_records)\n",
"\n",
" data = json.loads(data)\n",
" df = pd.DataFrame(data)\n",
"\n",
" return schema, df\n",
" except Exception as e:\n",
" return f'Error: {str(e)}', None\n",
"\n",
"\n",
"def generate_from_data_scientist(domain: str, context: str, num_records: int = 10):\n",
" try:\n",
" print(f'Domain: {domain}')\n",
" schema = create_domain_schema(domain)\n",
" print(schema)\n",
" data = generate_data(schema, context, num_records)\n",
" data = json.loads(data)\n",
" df = pd.DataFrame(data)\n",
"\n",
" return schema, df\n",
" except Exception as e:\n",
" return f'Error: {str(e)}', None\n",
"\n",
"\n",
"def generate_from_dynamic_fields(schema_name, context: str, num_fields, num_records: int, *field_values):\n",
" try:\n",
" fields = []\n",
" for i in range(num_fields):\n",
" idx = i * 4\n",
" if idx + 3 < len(field_values):\n",
" name = field_values[idx]\n",
" dtype = field_values[idx + 1]\n",
" nullable = field_values[idx + 2]\n",
" desc = field_values[idx + 3]\n",
"\n",
" if name and dtype:\n",
" fields.append(FieldDescriptor(\n",
" name=name,\n",
" data_type=dtype,\n",
" nullable=nullable if nullable is not None else False,\n",
" description=desc if desc else ''\n",
" ))\n",
"\n",
" if not schema_name:\n",
" return 'Error: Schema name is required', None\n",
"\n",
" if not fields:\n",
" return 'Error: At least one field is required', None\n",
"\n",
" schema = Schema(name=schema_name, fields=fields)\n",
" data = generate_data(schema.model_dump(), context , num_records)\n",
" data = json.loads(data)\n",
" df = pd.DataFrame(data)\n",
"\n",
"\n",
" return json.dumps(schema.model_dump(), indent=2), df\n",
"\n",
" except Exception as e:\n",
" return f'Error: {str(e)}', None\n",
"\n",
"\n",
"\n",
"title='✨ Coherent Data Generator'\n",
"\n",
"with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:\n",
" gr.Markdown(f'# {title}')\n",
" gr.Markdown('Embrace the Coherent Data wins 🏆!')\n",
"\n",
" df_state = gr.State(value=None)\n",
"\n",
" with gr.Row():\n",
" num_records_input = gr.Number(\n",
" label='Number of Records to Generate',\n",
" value=10,\n",
" minimum=1,\n",
" maximum=10000,\n",
" step=1,\n",
" precision=0\n",
" )\n",
"\n",
" context_input = gr.Textbox(\n",
" label='Context',\n",
" placeholder='70% Ghana and 30% Nigeria data. Start ID generation from 200',\n",
" lines=1\n",
" )\n",
"\n",
" with gr.Tabs() as tabs:\n",
" with gr.Tab('Manual Entry', id=0):\n",
" schema_name_input = gr.Textbox(label='Schema Name', placeholder='Enter schema name')\n",
"\n",
" gr.Markdown('### Fields')\n",
"\n",
" num_fields_state = gr.State(3)\n",
"\n",
" with gr.Row():\n",
" num_fields_slider = gr.Slider(\n",
" minimum=1,\n",
" maximum=20,\n",
" value=3,\n",
" step=1,\n",
" label='Number of Fields',\n",
" interactive=True\n",
" )\n",
"\n",
" gr.HTML('''\n",
" <div style=\"display: flex; gap: 8px; margin-bottom: 8px; font-weight: bold;\">\n",
" <div style=\"flex: 2;\">Field Name</div>\n",
" <div style=\"flex: 2;\">Data Type</div>\n",
" <div style=\"flex: 1;\">Nullable</div>\n",
" <div style=\"flex: 3;\">Description</div>\n",
" </div>\n",
" ''')\n",
"\n",
" field_components = []\n",
" row_components = []\n",
"\n",
" for i in range(20):\n",
" with gr.Row(visible=(i < 3)) as row:\n",
" field_name = gr.Textbox(label='', container=False, scale=2)\n",
" data_type = gr.Dropdown(choices=DATA_TYPES, value='string', label='', container=False, scale=2)\n",
" nullable = gr.Checkbox(label='', container=False, scale=1)\n",
" description = gr.Textbox(label='', container=False, scale=3)\n",
"\n",
" row_components.append(row)\n",
" field_components.extend([field_name, data_type, nullable, description])\n",
"\n",
" submit_btn = gr.Button('Generate', variant='primary')\n",
"\n",
" num_fields_slider.change(\n",
" fn=lambda x: [gr.update(visible=(i < x)) for i in range(20)],\n",
" inputs=[num_fields_slider],\n",
" outputs=row_components\n",
" )\n",
"\n",
"\n",
" with gr.Tab('SQL', id=1):\n",
" gr.Markdown('### Parse SQL DDL')\n",
" ddl_input = gr.Code(\n",
" value=sql,\n",
" label='SQL DDL Statement',\n",
" language='sql',\n",
" lines=10\n",
" )\n",
" ddl_btn = gr.Button('Generate', variant='primary')\n",
"\n",
"\n",
" with gr.Tab('>_ Prompt', id=2):\n",
" gr.Markdown('### You are on your own here, so be creative 💡')\n",
" prompt_input = gr.Textbox(\n",
" label='Prompt',\n",
" placeholder='Type your prompt',\n",
" lines=10\n",
" )\n",
" prompt_btn = gr.Button('Generate', variant='primary')\n",
"\n",
" with gr.Tab('Data Scientist 🎩', id=3):\n",
" gr.Markdown('### You are on your own here, so be creative 💡')\n",
" domain_input = gr.Dropdown(\n",
" label='Domain',\n",
" choices=['E-commerce Customers', 'Hospital Patients', 'Loan Applications'],\n",
" allow_custom_value=True\n",
" )\n",
"\n",
" data_scientist_generate_btn = gr.Button('Generate', variant='primary')\n",
"\n",
"\n",
" with gr.Accordion('Generated Schema', open=False):\n",
" output = gr.Code(label='Schema (JSON)', language='json')\n",
"\n",
" gr.Markdown('## Generated Data')\n",
" dataframe_output = gr.Dataframe(\n",
" label='',\n",
" interactive=False,\n",
" wrap=True\n",
" )\n",
"\n",
" gr.Markdown('### Export Data')\n",
" with gr.Row():\n",
" format_dropdown = gr.Dropdown(\n",
" choices=[format.value for format in ExportFormat],\n",
" value=ExportFormat.CSV,\n",
" label='Export Format',\n",
" scale=2\n",
" )\n",
" download_btn = gr.Button('Download', variant='secondary', scale=1)\n",
"\n",
" download_file = gr.File(label='Download File', visible=True)\n",
"\n",
"\n",
" def _handle_result(result):\n",
" if isinstance(result, tuple) and len(result) == 2:\n",
" return result[0], result[1], result[1]\n",
" return result[0], result[1], None\n",
"\n",
"\n",
" def update_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values):\n",
" result = generate_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values)\n",
" return _handle_result(result)\n",
"\n",
"\n",
" def update_from_sql(sql: str, context, num_records: int):\n",
" result = generate_from_sql(sql, context, num_records)\n",
" return _handle_result(result)\n",
"\n",
"\n",
" def update_from_data_scientist(domain: str, context, num_records: int):\n",
" result = generate_from_data_scientist(domain, context, num_records)\n",
" return _handle_result(result)\n",
"\n",
"\n",
" submit_btn.click(\n",
" fn=update_from_dynamic_fields,\n",
" inputs=[schema_name_input, context_input, num_fields_slider, num_records_input] + field_components,\n",
" outputs=[output, dataframe_output, df_state]\n",
" )\n",
"\n",
" ddl_btn.click(\n",
" fn=update_from_sql,\n",
" inputs=[ddl_input, context_input, num_records_input],\n",
" outputs=[output, dataframe_output, df_state]\n",
" )\n",
"\n",
" data_scientist_generate_btn.click(\n",
" fn=update_from_data_scientist,\n",
" inputs=[domain_input, context_input, num_records_input],\n",
" outputs=[output, dataframe_output, df_state]\n",
" )\n",
"\n",
"\n",
" download_btn.click(\n",
" fn=prepare_download,\n",
" inputs=[df_state, format_dropdown],\n",
" outputs=[download_file]\n",
" )\n",
"\n",
"\n",
"ui.launch(debug=True)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"tqSpfJGnme7y"
],
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}