Merge pull request #862 from ranskills/week3-coherent-data-generator
Bootcamp(Ransford): Week 3 - Coherent Data Generator
@@ -0,0 +1,734 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "KbMea_UrO3Ke"
},
"source": [
"# ✨ Coherent Data Generator\n",
|
||||
"\n",
|
||||
"## In real life, data has meaning, relationships, etc., and this is where this tool shines.\n",
|
||||
"\n",
|
||||
"Dependencies between fields are detected, and coherent data is generated.\n",
|
||||
"Example:\n",
|
||||
"When asked to generate data with **Ghana** cited as the context, fields like `name`, `food`, etc., will be Ghanaian. Fields such as phone number will have the appropriate prefix of `+233`, etc.\n",
|
||||
"\n",
|
||||
"This is better than Faker.\n",
|
||||
"\n",
|
||||
"## Steps\n",
|
||||
"Schema -> Generate Data\n",
|
||||
"\n",
|
||||
"Schema Sources: \n",
|
||||
"- Use the guided schema builder\n",
|
||||
"- Bring your own schema from an SQL Data Definition Language (DDL)\n",
|
||||
"- Prompting\n",
|
||||
"- Providing a domain to an old hat to define features for a dataset"
|
||||
]
|
||||
},
|
||||
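{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, a generated row for the Ghana context might look like the record below. This is an illustrative, hypothetical example; actual field names and values depend on the schema you supply and on the model's sampling:\n",
"\n",
"```python\n",
"{'name': 'Kwame Mensah', 'food': 'Jollof rice', 'mobile_number': '+233 24 123 4567'}\n",
"```"
]
},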
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cN8z-QNlFtYc"
},
"outputs": [],
"source": [
"import json\n",
"\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"import torch\n",
"import pandas as pd\n",
"\n",
"from pydantic import BaseModel, Field\n",
"from IPython.display import display, Markdown"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DOBBN3P2GD2O"
},
"outputs": [],
"source": [
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
"\n",
"device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'\n",
"print(f'Device: {device}')\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    model_id,\n",
"    dtype=\"auto\",\n",
"    device_map=\"auto\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSUebXa1O3MM"
},
"source": [
"## Schema Definitions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5LNM76OQjAw6"
},
"outputs": [],
"source": [
"# This is for future use, where errors in SQL DDL statements can be fixed if the\n",
"# user specifies that from the UI\n",
"class SQLValidationResult(BaseModel):\n",
"    is_valid: bool\n",
"    is_fixable: bool\n",
"    reason: str = Field(default='', description='validation failure reason')\n",
"\n",
"\n",
"class FieldDescriptor(BaseModel):\n",
"    name: str = Field(..., description='Name of the field')\n",
"    data_type: str = Field(..., description='Type of the field')\n",
"    nullable: bool\n",
"    description: str = Field(..., description='Description of the field')\n",
"\n",
"\n",
"class Schema(BaseModel):\n",
"    name: str = Field(..., description='Name of the schema')\n",
"    fields: list[FieldDescriptor] = Field(..., description='List of fields in the schema')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6QjitfTBPa1E"
},
"source": [
"## LLM Interactions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dXiRHok7Peir"
},
"source": [
"### Generate Content from LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "daTUVG8_PmvM"
},
"outputs": [],
"source": [
"def generate(messages: list[dict[str, str]], temperature: float = 0.1) -> str:\n",
"    text = tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=False,\n",
"        add_generation_prompt=True,\n",
"    )\n",
"    model_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n",
"\n",
"    generated_ids = model.generate(\n",
"        **model_inputs,\n",
"        max_new_tokens=16384,\n",
"        temperature=temperature\n",
"    )\n",
"\n",
"    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n",
"    content = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
"\n",
"    return content"
]
},
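{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch of `generate` (illustrative only; the message list below is hypothetical, and the reply depends on the model and sampling settings):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical smoke test: `generate` takes OpenAI-style chat messages and\n",
"# returns the decoded completion as a string.\n",
"reply = generate([\n",
"    {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
"    {'role': 'user', 'content': 'Reply with the single word: ready'}\n",
"])\n",
"print(reply)"
]
},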
{
"cell_type": "markdown",
"metadata": {
"id": "sBHJKn8qQhM5"
},
"source": [
"### Generate Data Given A Valid Schema"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fla8UQf4Qm5l"
},
"outputs": [],
"source": [
"def generate_data(schema: str, context: str = '', num_records: int = 5):\n",
"    system_prompt = f'''\n",
"    You are a synthetic data generator; you generate data based on the given schema\n",
"    in a specific JSON structure.\n",
"    When a context is provided, intelligently use it to drive the field generation.\n",
"\n",
"    Example:\n",
"    If Africa is given as the context, fields like name, first_name, last_name, etc.\n",
"    will be generated with values that fit Africa.\n",
"\n",
"    If no context is provided, generate data randomly.\n",
"\n",
"    Output an array of JSON objects.\n",
"    '''\n",
"\n",
"    prompt = f'''\n",
"    Generate {num_records} records:\n",
"\n",
"    Schema:\n",
"    {schema}\n",
"\n",
"    Context:\n",
"    {context}\n",
"    '''\n",
"\n",
"    messages = [\n",
"        {'role': 'system', 'content': system_prompt},\n",
"        {\"role\": \"user\", \"content\": prompt}\n",
"    ]\n",
"\n",
"    return generate(messages)"
]
},
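{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instruct models often wrap their JSON in Markdown code fences, which would make a bare `json.loads` fail. A small helper like the hypothetical `extract_json` below (not part of the original pipeline; a hedged sketch only) can make the parsing in the UI callbacks more forgiving:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import re\n",
"\n",
"\n",
"def extract_json(raw: str):\n",
"    # Hypothetical helper: strip a surrounding ```json ... ``` fence if present,\n",
"    # then parse whatever remains.\n",
"    match = re.search(r'```(?:json)?\\s*(.*?)\\s*```', raw, re.DOTALL)\n",
"    payload = match.group(1) if match else raw\n",
"    return json.loads(payload)\n",
"\n",
"\n",
"# Example: both fenced and bare payloads parse to the same list.\n",
"assert extract_json('```json\\n[{\"a\": 1}]\\n```') == extract_json('[{\"a\": 1}]')"
]
},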
{
"cell_type": "markdown",
"metadata": {
"id": "izrClU6VPsZp"
},
"source": [
"### SQL"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aQgY6EK0QPPd"
},
"outputs": [],
"source": [
"def sql_validator(ddl: str):\n",
"    system_prompt = '''\n",
"    You are an SQL validator; your task is to validate whether the given SQL is valid or not.\n",
"    ONLY return a binary response of 1 or 0, where 1 = valid and 0 = not valid.\n",
"    '''\n",
"    prompt = f'Validate: {ddl}'\n",
"\n",
"    messages = [\n",
"        {'role': 'system', 'content': system_prompt},\n",
"        {\"role\": \"user\", \"content\": prompt}\n",
"    ]\n",
"\n",
"    return generate(messages)\n",
"\n",
"\n",
"# Future work: this will fix any errors in the SQL DDL statement, provided it is\n",
"# fixable.\n",
"def sql_fixer(ddl: str):\n",
"    pass\n",
"\n",
"\n",
"def parse_ddl(ddl: str):\n",
"    system_prompt = f'''\n",
"    You are an SQL analyzer; your task is to extract column information into a\n",
"    specific JSON structure.\n",
"\n",
"    The output must conform to the following JSON schema:\n",
"    {Schema.model_json_schema()}\n",
"    '''\n",
"    prompt = f'Generate schema for: {ddl}'\n",
"\n",
"    messages = [\n",
"        {'role': 'system', 'content': system_prompt},\n",
"        {\"role\": \"user\", \"content\": prompt}\n",
"    ]\n",
"\n",
"    return generate(messages)"
]
},
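{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustrative call of the SQL path (the table below is hypothetical; the returned text is model-dependent and may arrive fenced, see `extract_json` above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical smoke test for the DDL path; outputs vary with the model.\n",
"demo_ddl = 'CREATE TABLE orders (id BIGINT PRIMARY KEY, amount DECIMAL(10, 2));'\n",
"print(sql_validator(demo_ddl))  # should print 1 per the validator prompt\n",
"print(parse_ddl(demo_ddl))      # JSON text conforming to Schema.model_json_schema()"
]
},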
{
"cell_type": "markdown",
"metadata": {
"id": "4mgwDQyDQ1wv"
},
"source": [
"### Data Scientist\n",
"\n",
"Just give it a domain and you will be amazed at the features it gives you."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P36AMvBq8AST"
},
"outputs": [],
"source": [
"def create_domain_schema(domain: str):\n",
"    system_prompt = f'''\n",
"    You are an expert Data Scientist tasked to describe features for a dataset\n",
"    for aspiring data scientists in a chosen domain.\n",
"\n",
"    Follow these steps EXACTLY:\n",
"    **Define 6–10 features** for the given domain. Include:\n",
"    - At least 2 numerical features\n",
"    - At least 2 categorical features\n",
"    - 1 boolean or binary feature\n",
"    - 1 timestamp or date feature\n",
"    - Realistic dependencies (e.g., \"if loan_amount > 50000, credit_score should be high\")\n",
"\n",
"    Populate your response into the JSON schema below. Strictly output **JSON**:\n",
"    {Schema.model_json_schema()}\n",
"    '''\n",
"    prompt = f'Describe the data points. Domain: {domain}'\n",
"\n",
"    messages = [\n",
"        {'role': 'system', 'content': system_prompt},\n",
"        {\"role\": \"user\", \"content\": prompt}\n",
"    ]\n",
"\n",
"    return generate(messages)\n",
"\n",
"\n",
"# TODO: Use Gradio Examples to make it easier to load different statements\n",
"sql = '''\n",
"CREATE TABLE users (\n",
"    id BIGINT PRIMARY KEY,\n",
"    name VARCHAR(100) NOT NULL,\n",
"    email TEXT,\n",
"    gender ENUM('F', 'M'),\n",
"    country VARCHAR(100),\n",
"    mobile_number VARCHAR(100),\n",
"    created_at TIMESTAMP DEFAULT NOW()\n",
");\n",
"'''"
]
},
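{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, illustrative example of the domain path (output is model-dependent):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: ask the data-scientist persona to propose features for a domain.\n",
"print(create_domain_schema('Loan Applications'))"
]
},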
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuVyHOhjDtSH"
},
"outputs": [],
"source": [
"print(f'{model.get_memory_footprint() / 1e9:,.2f} GB')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqSpfJGnme7y"
},
"source": [
"## Export Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pAu5OPfUmMSm"
},
"outputs": [],
"source": [
"import io\n",
"import tempfile\n",
"from enum import StrEnum  # StrEnum requires Python 3.11+\n",
"\n",
"\n",
"class ExportFormat(StrEnum):\n",
"    CSV = 'CSV'\n",
"    JSON = 'JSON'\n",
"    Excel = 'Excel'\n",
"    Parquet = 'Parquet'\n",
"    TSV = 'TSV'\n",
"    HTML = 'HTML'\n",
"    Markdown = 'Markdown'\n",
"    SQL = 'SQL'\n",
"\n",
"\n",
"def export_data(df, format_type):\n",
"    if df is None or df.empty:\n",
"        return None\n",
"\n",
"    try:\n",
"        if format_type == ExportFormat.CSV:\n",
"            output = io.StringIO()\n",
"            df.to_csv(output, index=False)\n",
"            return output.getvalue()\n",
"\n",
"        elif format_type == ExportFormat.JSON:\n",
"            return df.to_json(orient='records', indent=2)\n",
"\n",
"        elif format_type == ExportFormat.Excel:\n",
"            output = io.BytesIO()\n",
"            df.to_excel(output, index=False, engine='openpyxl')\n",
"            return output.getvalue()\n",
"\n",
"        elif format_type == ExportFormat.Parquet:\n",
"            output = io.BytesIO()\n",
"            df.to_parquet(output, index=False)\n",
"            return output.getvalue()\n",
"\n",
"        elif format_type == ExportFormat.TSV:\n",
"            output = io.StringIO()\n",
"            df.to_csv(output, sep='\\t', index=False)\n",
"            return output.getvalue()\n",
"\n",
"        elif format_type == ExportFormat.HTML:\n",
"            return df.to_html(index=False)\n",
"\n",
"        elif format_type == ExportFormat.Markdown:\n",
"            return df.to_markdown(index=False)\n",
"\n",
"        elif format_type == ExportFormat.SQL:\n",
"            from sqlalchemy import create_engine\n",
"            engine = create_engine('sqlite:///:memory:')\n",
"            table = 'users'  # TODO: fix this\n",
"\n",
"            df.to_sql(table, con=engine, index=False)\n",
"            connection = engine.raw_connection()\n",
"            sql_statements = list(connection.iterdump())\n",
"            sql_output_string = \"\\n\".join(sql_statements)\n",
"            connection.close()\n",
"\n",
"            return sql_output_string\n",
"\n",
"    except Exception as e:\n",
"        print(f\"Export error: {str(e)}\")\n",
"        return None\n",
"\n",
"\n",
"def prepare_download(df, format_type):\n",
"    if df is None:\n",
"        return None\n",
"\n",
"    content = export_data(df, format_type)\n",
"    if content is None:\n",
"        return None\n",
"\n",
"    extensions = {\n",
"        ExportFormat.CSV: '.csv',\n",
"        ExportFormat.JSON: '.json',\n",
"        ExportFormat.Excel: '.xlsx',\n",
"        ExportFormat.Parquet: '.parquet',\n",
"        ExportFormat.TSV: '.tsv',\n",
"        ExportFormat.HTML: '.html',\n",
"        ExportFormat.Markdown: '.md',\n",
"        ExportFormat.SQL: '.sql',\n",
"    }\n",
"\n",
"    filename = f'generated_data{extensions.get(format_type, \".txt\")}'\n",
"\n",
"    is_binary_format = format_type in [ExportFormat.Excel, ExportFormat.Parquet]\n",
"    mode = 'w+b' if is_binary_format else 'w'\n",
"\n",
"    with tempfile.NamedTemporaryFile(mode=mode, delete=False, suffix=extensions[format_type]) as tmp:\n",
"        tmp.write(content)\n",
"        tmp.flush()\n",
"        return tmp.name"
]
},
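{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sanity check for the export helpers (illustrative; uses a throwaway DataFrame rather than generated data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip a tiny frame through the text formats to confirm the helpers work.\n",
"demo_df = pd.DataFrame([{'name': 'Ama', 'country': 'Ghana'}, {'name': 'Ngozi', 'country': 'Nigeria'}])\n",
"print(export_data(demo_df, ExportFormat.CSV))\n",
"print(export_data(demo_df, ExportFormat.JSON))\n",
"print(prepare_download(demo_df, ExportFormat.CSV))  # path to a temp .csv file"
]
},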
{
"cell_type": "markdown",
"metadata": {
"id": "Q0fZsCuso_YZ"
},
"source": [
"## Gradio UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TJYUWecybDpP",
"outputId": "e82d0a13-3ca3-4a01-d45c-78fc94ade9bc"
},
"outputs": [],
"source": [
"import gradio as gr\n",
"from pydantic import BaseModel, Field\n",
"import json\n",
"import pandas as pd\n",
"import io\n",
"\n",
"DATA_TYPES = ['string', 'integer', 'float', 'boolean', 'date', 'datetime', 'array', 'object']\n",
"\n",
"def generate_from_sql(sql: str, context: str, num_records: int = 10):\n",
"    try:\n",
"        print(f'SQL: {sql}')\n",
"        schema = parse_ddl(sql)\n",
"        data = generate_data(schema, context, num_records)\n",
"\n",
"        data = json.loads(data)\n",
"        df = pd.DataFrame(data)\n",
"\n",
"        return schema, df\n",
"    except Exception as e:\n",
"        return f'Error: {str(e)}', None\n",
"\n",
"\n",
"def generate_from_data_scientist(domain: str, context: str, num_records: int = 10):\n",
"    try:\n",
"        print(f'Domain: {domain}')\n",
"        schema = create_domain_schema(domain)\n",
"        print(schema)\n",
"        data = generate_data(schema, context, num_records)\n",
"        data = json.loads(data)\n",
"        df = pd.DataFrame(data)\n",
"\n",
"        return schema, df\n",
"    except Exception as e:\n",
"        return f'Error: {str(e)}', None\n",
"\n",
"\n",
"def generate_from_dynamic_fields(schema_name, context: str, num_fields, num_records: int, *field_values):\n",
"    try:\n",
"        fields = []\n",
"        for i in range(num_fields):\n",
"            idx = i * 4\n",
"            if idx + 3 < len(field_values):\n",
"                name = field_values[idx]\n",
"                dtype = field_values[idx + 1]\n",
"                nullable = field_values[idx + 2]\n",
"                desc = field_values[idx + 3]\n",
"\n",
"                if name and dtype:\n",
"                    fields.append(FieldDescriptor(\n",
"                        name=name,\n",
"                        data_type=dtype,\n",
"                        nullable=nullable if nullable is not None else False,\n",
"                        description=desc if desc else ''\n",
"                    ))\n",
"\n",
"        if not schema_name:\n",
"            return 'Error: Schema name is required', None\n",
"\n",
"        if not fields:\n",
"            return 'Error: At least one field is required', None\n",
"\n",
"        schema = Schema(name=schema_name, fields=fields)\n",
"        data = generate_data(schema.model_dump(), context, num_records)\n",
"        data = json.loads(data)\n",
"        df = pd.DataFrame(data)\n",
"\n",
"        return json.dumps(schema.model_dump(), indent=2), df\n",
"\n",
"    except Exception as e:\n",
"        return f'Error: {str(e)}', None\n",
"\n",
"\n",
"title = '✨ Coherent Data Generator'\n",
"\n",
"with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:\n",
"    gr.Markdown(f'# {title}')\n",
"    gr.Markdown('Embrace the Coherent Data wins 🏆!')\n",
"\n",
"    df_state = gr.State(value=None)\n",
"\n",
"    with gr.Row():\n",
"        num_records_input = gr.Number(\n",
"            label='Number of Records to Generate',\n",
"            value=10,\n",
"            minimum=1,\n",
"            maximum=10000,\n",
"            step=1,\n",
"            precision=0\n",
"        )\n",
"\n",
"        context_input = gr.Textbox(\n",
"            label='Context',\n",
"            placeholder='70% Ghana and 30% Nigeria data. Start ID generation from 200',\n",
"            lines=1\n",
"        )\n",
"\n",
"    with gr.Tabs() as tabs:\n",
"        with gr.Tab('Manual Entry', id=0):\n",
"            schema_name_input = gr.Textbox(label='Schema Name', placeholder='Enter schema name')\n",
"\n",
"            gr.Markdown('### Fields')\n",
"\n",
"            num_fields_state = gr.State(3)\n",
"\n",
"            with gr.Row():\n",
"                num_fields_slider = gr.Slider(\n",
"                    minimum=1,\n",
"                    maximum=20,\n",
"                    value=3,\n",
"                    step=1,\n",
"                    label='Number of Fields',\n",
"                    interactive=True\n",
"                )\n",
"\n",
"            gr.HTML('''\n",
"            <div style=\"display: flex; gap: 8px; margin-bottom: 8px; font-weight: bold;\">\n",
"                <div style=\"flex: 2;\">Field Name</div>\n",
"                <div style=\"flex: 2;\">Data Type</div>\n",
"                <div style=\"flex: 1;\">Nullable</div>\n",
"                <div style=\"flex: 3;\">Description</div>\n",
"            </div>\n",
"            ''')\n",
"\n",
"            field_components = []\n",
"            row_components = []\n",
"\n",
"            for i in range(20):\n",
"                with gr.Row(visible=(i < 3)) as row:\n",
"                    field_name = gr.Textbox(label='', container=False, scale=2)\n",
"                    data_type = gr.Dropdown(choices=DATA_TYPES, value='string', label='', container=False, scale=2)\n",
"                    nullable = gr.Checkbox(label='', container=False, scale=1)\n",
"                    description = gr.Textbox(label='', container=False, scale=3)\n",
"\n",
"                row_components.append(row)\n",
"                field_components.extend([field_name, data_type, nullable, description])\n",
"\n",
"            submit_btn = gr.Button('Generate', variant='primary')\n",
"\n",
"            num_fields_slider.change(\n",
"                fn=lambda x: [gr.update(visible=(i < x)) for i in range(20)],\n",
"                inputs=[num_fields_slider],\n",
"                outputs=row_components\n",
"            )\n",
"\n",
"        with gr.Tab('SQL', id=1):\n",
"            gr.Markdown('### Parse SQL DDL')\n",
"            ddl_input = gr.Code(\n",
"                value=sql,\n",
"                label='SQL DDL Statement',\n",
"                language='sql',\n",
"                lines=10\n",
"            )\n",
"            ddl_btn = gr.Button('Generate', variant='primary')\n",
"\n",
"        with gr.Tab('>_ Prompt', id=2):\n",
"            gr.Markdown('### You are on your own here, so be creative 💡')\n",
"            prompt_input = gr.Textbox(\n",
"                label='Prompt',\n",
"                placeholder='Type your prompt',\n",
"                lines=10\n",
"            )\n",
"            prompt_btn = gr.Button('Generate', variant='primary')  # not wired up yet; prompt-driven generation is future work\n",
"\n",
"        with gr.Tab('Data Scientist 🎩', id=3):\n",
"            gr.Markdown('### Pick a domain and let the expert define the features 🎩')\n",
"            domain_input = gr.Dropdown(\n",
"                label='Domain',\n",
"                choices=['E-commerce Customers', 'Hospital Patients', 'Loan Applications'],\n",
"                allow_custom_value=True\n",
"            )\n",
"\n",
"            data_scientist_generate_btn = gr.Button('Generate', variant='primary')\n",
"\n",
"    with gr.Accordion('Generated Schema', open=False):\n",
"        output = gr.Code(label='Schema (JSON)', language='json')\n",
"\n",
"    gr.Markdown('## Generated Data')\n",
"    dataframe_output = gr.Dataframe(\n",
"        label='',\n",
"        interactive=False,\n",
"        wrap=True\n",
"    )\n",
"\n",
"    gr.Markdown('### Export Data')\n",
"    with gr.Row():\n",
"        format_dropdown = gr.Dropdown(\n",
"            choices=[format.value for format in ExportFormat],\n",
"            value=ExportFormat.CSV,\n",
"            label='Export Format',\n",
"            scale=2\n",
"        )\n",
"        download_btn = gr.Button('Download', variant='secondary', scale=1)\n",
"\n",
"    download_file = gr.File(label='Download File', visible=True)\n",
"\n",
"\n",
"    def _handle_result(result):\n",
"        if isinstance(result, tuple) and len(result) == 2:\n",
"            return result[0], result[1], result[1]\n",
"        return result[0], result[1], None\n",
"\n",
"\n",
"    def update_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values):\n",
"        result = generate_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values)\n",
"        return _handle_result(result)\n",
"\n",
"\n",
"    def update_from_sql(sql: str, context, num_records: int):\n",
"        result = generate_from_sql(sql, context, num_records)\n",
"        return _handle_result(result)\n",
"\n",
"\n",
"    def update_from_data_scientist(domain: str, context, num_records: int):\n",
"        result = generate_from_data_scientist(domain, context, num_records)\n",
"        return _handle_result(result)\n",
"\n",
"\n",
"    submit_btn.click(\n",
"        fn=update_from_dynamic_fields,\n",
"        inputs=[schema_name_input, context_input, num_fields_slider, num_records_input] + field_components,\n",
"        outputs=[output, dataframe_output, df_state]\n",
"    )\n",
"\n",
"    ddl_btn.click(\n",
"        fn=update_from_sql,\n",
"        inputs=[ddl_input, context_input, num_records_input],\n",
"        outputs=[output, dataframe_output, df_state]\n",
"    )\n",
"\n",
"    data_scientist_generate_btn.click(\n",
"        fn=update_from_data_scientist,\n",
"        inputs=[domain_input, context_input, num_records_input],\n",
"        outputs=[output, dataframe_output, df_state]\n",
"    )\n",
"\n",
"\n",
"    download_btn.click(\n",
"        fn=prepare_download,\n",
"        inputs=[df_state, format_dropdown],\n",
"        outputs=[download_file]\n",
"    )\n",
"\n",
"\n",
"ui.launch(debug=True)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"tqSpfJGnme7y"
],
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}