This commit is contained in:
Dmitry Kisselev
2025-10-19 15:18:53 -07:00
parent f0718b6512
commit 5cbf627469

View File

@@ -44,6 +44,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "m-yhYlN4OQEC",
"metadata": {
"id": "m-yhYlN4OQEC"
},
"outputs": [],
"source": [ "source": [
"gpu_info = !nvidia-smi\n", "gpu_info = !nvidia-smi\n",
"gpu_info = '\\n'.join(gpu_info)\n", "gpu_info = '\\n'.join(gpu_info)\n",
@@ -55,13 +61,7 @@
" print(\"Success - Connected to a T4\")\n", " print(\"Success - Connected to a T4\")\n",
" else:\n", " else:\n",
" print(\"NOT CONNECTED TO A T4\")" " print(\"NOT CONNECTED TO A T4\")"
], ]
"metadata": {
"id": "m-yhYlN4OQEC"
},
"id": "m-yhYlN4OQEC",
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -80,20 +80,16 @@
"import re\n", "import re\n",
"import gc\n", "import gc\n",
"import torch\n", "import torch\n",
"from typing import List, Dict, Any, Optional, Tuple\n", "from typing import List, Dict, Any, Tuple\n",
"from pathlib import Path\n",
"import warnings\n", "import warnings\n",
"warnings.filterwarnings(\"ignore\")\n", "warnings.filterwarnings(\"ignore\")\n",
"\n", "\n",
"# LLM APIs\n", "# LLM APIs\n",
"from openai import OpenAI\n", "from openai import OpenAI\n",
"# import anthropic\n",
"# import google.generativeai as genai\n",
"# from deepseek import DeepSeek\n",
"\n", "\n",
"# HuggingFace\n", "# HuggingFace\n",
"from huggingface_hub import login\n", "from huggingface_hub import login\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
"\n", "\n",
"# Data processing\n", "# Data processing\n",
"import nltk\n", "import nltk\n",
@@ -203,45 +199,52 @@
"HUGGINGFACE_MODELS = {\n", "HUGGINGFACE_MODELS = {\n",
" \"Llama 3.1 8B\": {\n", " \"Llama 3.1 8B\": {\n",
" \"model_id\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " \"model_id\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" \"description\": \"8B model - that is good for structured data generation\",\n", " \"description\": \"Good for structured data generation\",\n",
" \"size\": \"8B\",\n", " \"size\": \"8B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"LlamaForCausalLM\"\n",
" },\n", " },\n",
" \"Llama 3.2 3B\": {\n", " \"Llama 3.2 3B\": {\n",
" \"model_id\": \"meta-llama/Llama-3.2-3B-Instruct\",\n", " \"model_id\": \"meta-llama/Llama-3.2-3B-Instruct\",\n",
" \"description\": \"3B model - smaller and faster model that is good for simple schemas\",\n", " \"description\": \"Smaller and faster model for simple schemas\",\n",
" \"size\": \"3B\",\n", " \"size\": \"3B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"LlamaForCausalLM\"\n",
" },\n", " },\n",
" \"Phi-3.5 Mini\": {\n", " \"Phi-3.5 Mini\": {\n",
" \"model_id\": \"microsoft/Phi-3.5-mini-instruct\",\n", " \"model_id\": \"microsoft/Phi-3.5-mini-instruct\",\n",
" \"description\": \"3.8B model - with reasoning capabilities\",\n", " \"description\": \"Reasoning capabilities\",\n",
" \"size\": \"3.8B\",\n", " \"size\": \"3.8B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"Phi3ForCausalLM\"\n",
" },\n", " },\n",
" \"Gemma 2 9B\": {\n", " \"Gemma 2 9B\": {\n",
" \"model_id\": \"google/gemma-2-9b-it\",\n", " \"model_id\": \"google/gemma-2-9b-it\",\n",
" \"description\": \"9B model - instruction-tuned model\",\n", " \"description\": \"Instruction-tuned model\",\n",
" \"size\": \"9B\",\n", " \"size\": \"9B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"GemmaForCausalLM\"\n",
" },\n", " },\n",
" \"Qwen 2.5 7B\": {\n", " \"Qwen 2.5 7B\": {\n",
" \"model_id\": \"Qwen/Qwen2.5-7B-Instruct\",\n", " \"model_id\": \"Qwen/Qwen2.5-7B-Instruct\",\n",
" \"description\": \"7B model - multilingual that is good for diverse data\",\n", " \"description\": \"Multilingual that is good for diverse data\",\n",
" \"size\": \"7B\",\n", " \"size\": \"7B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"Qwen2ForCausalLM\"\n",
" },\n", " },\n",
" \"Mistral 7B\": {\n", " \"Mistral 7B\": {\n",
" \"model_id\": \"mistralai/Mistral-7B-Instruct-v0.3\",\n", " \"model_id\": \"mistralai/Mistral-7B-Instruct-v0.3\",\n",
" \"description\": \"7B model - fast inference\",\n", " \"description\": \"Fast inference\",\n",
" \"size\": \"7B\",\n", " \"size\": \"7B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"MistralForCausalLM\"\n",
" },\n", " },\n",
" \"Zephyr 7B\": {\n", " \"Zephyr 7B\": {\n",
" \"model_id\": \"HuggingFaceH4/zephyr-7b-beta\",\n", " \"model_id\": \"HuggingFaceH4/zephyr-7b-beta\",\n",
" \"description\": \"7B model - fine-tuned for instruction following\",\n", " \"description\": \"Fine-tuned for instruction following\",\n",
" \"size\": \"7B\",\n", " \"size\": \"7B\",\n",
" \"type\": \"huggingface\"\n", " \"type\": \"huggingface\",\n",
" \"model_class\": \"ZephyrForCausalLM\"\n",
" }\n", " }\n",
"}\n", "}\n",
"\n", "\n",
@@ -305,15 +308,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"schema_manager.generate_schema_with_llm(\"realstate dataset for residential houses\",'Gemini 2.5 Flash', 0.7)" "id": "dFYWA5y0ZmJr",
],
"metadata": { "metadata": {
"id": "dFYWA5y0ZmJr" "id": "dFYWA5y0ZmJr"
}, },
"id": "dFYWA5y0ZmJr", "outputs": [],
"execution_count": null, "source": [
"outputs": [] "schema_manager.generate_schema_with_llm(\"realstate dataset for residential houses\",'Gemini 2.5 Flash', 0.7)"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -460,6 +463,63 @@
"print(\"✅ Schema Management Module loaded!\")\n" "print(\"✅ Schema Management Module loaded!\")\n"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "52c7cb55",
"metadata": {},
"outputs": [],
"source": [
"# Fixed HuggingFace Model Loading\n",
"def load_huggingface_model_with_correct_class(model_id, model_class_name, quantization_config, torch_dtype):\n",
" \"\"\"Load HuggingFace model with correct model class\"\"\"\n",
" try:\n",
" # Import the specific model class\n",
" if model_class_name == \"LlamaForCausalLM\":\n",
" from transformers import LlamaForCausalLM\n",
" model_class = LlamaForCausalLM\n",
" elif model_class_name == \"Phi3ForCausalLM\":\n",
" from transformers import Phi3ForCausalLM\n",
" model_class = Phi3ForCausalLM\n",
" elif model_class_name == \"GemmaForCausalLM\":\n",
" from transformers import GemmaForCausalLM\n",
" model_class = GemmaForCausalLM\n",
" elif model_class_name == \"Qwen2ForCausalLM\":\n",
" from transformers import Qwen2ForCausalLM\n",
" model_class = Qwen2ForCausalLM\n",
" elif model_class_name == \"MistralForCausalLM\":\n",
" from transformers import MistralForCausalLM\n",
" model_class = MistralForCausalLM\n",
" else:\n",
" # Fallback to AutoModelForCausalLM\n",
" model_class = AutoModelForCausalLM\n",
" \n",
" # Load the model\n",
" model = model_class.from_pretrained(\n",
" model_id,\n",
" device_map=\"auto\",\n",
" quantization_config=quantization_config,\n",
" torch_dtype=torch_dtype\n",
" )\n",
" return model\n",
" \n",
" except Exception as e:\n",
" print(f\"Error loading {model_class_name}: {str(e)}\")\n",
" # Fallback to AutoModelForCausalLM\n",
" try:\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" device_map=\"auto\",\n",
" quantization_config=quantization_config,\n",
" torch_dtype=torch_dtype\n",
" )\n",
" return model\n",
" except Exception as e2:\n",
" raise Exception(f\"Failed to load model with both specific and auto classes: {str(e2)}\")\n",
"\n",
"print(\"✅ Fixed HuggingFace model loading function created!\")\n"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -558,12 +618,13 @@
" tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n", " tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
" tokenizer.pad_token = tokenizer.eos_token\n", " tokenizer.pad_token = tokenizer.eos_token\n",
"\n", "\n",
" # Load model with quantization\n", " # Load model with quantization using correct model class\n",
" model = AutoModelForCausalLM.from_pretrained(\n", " model_class_name = model_info.get(\"model_class\", \"AutoModelForCausalLM\")\n",
" model_id,\n", " model = load_huggingface_model_with_correct_class(\n",
" device_map=\"auto\",\n", " model_id, \n",
" quantization_config=self.quantization_config,\n", " model_class_name, \n",
" torch_dtype=torch.bfloat16\n", " self.quantization_config, \n",
" torch.bfloat16\n",
" )\n", " )\n",
"\n", "\n",
" self.loaded_models[model_name] = {\n", " self.loaded_models[model_name] = {\n",
@@ -674,6 +735,596 @@
"print(\"✅ Dataset Generation Module loaded!\")\n" "print(\"✅ Dataset Generation Module loaded!\")\n"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "d5d8d07f",
"metadata": {},
"outputs": [],
"source": [
"# Fixed Schema Generation for HuggingFace Models\n",
"def fixed_schema_generation_huggingface(model_name: str, system_prompt: str, user_prompt: str, temperature: float) -> str:\n",
" \"\"\"Fixed HuggingFace schema generation\"\"\"\n",
" model_info = HUGGINGFACE_MODELS[model_name]\n",
" model_id = model_info[\"model_id\"]\n",
"\n",
" try:\n",
" # Check if model is already loaded\n",
" if model_name not in dataset_generator.loaded_models:\n",
" print(f\"🔄 Loading {model_name} for schema generation...\")\n",
"\n",
" # Load tokenizer\n",
" tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" # Load model with quantization using correct model class\n",
" model_class_name = model_info.get(\"model_class\", \"AutoModelForCausalLM\")\n",
" model = load_huggingface_model_with_correct_class(\n",
" model_id, \n",
" model_class_name, \n",
" dataset_generator.quantization_config, \n",
" torch.bfloat16\n",
" )\n",
"\n",
" dataset_generator.loaded_models[model_name] = {\n",
" 'model': model,\n",
" 'tokenizer': tokenizer\n",
" }\n",
" print(f\"✅ {model_name} loaded successfully for schema generation!\")\n",
"\n",
" # Get model and tokenizer\n",
" model = dataset_generator.loaded_models[model_name]['model']\n",
" tokenizer = dataset_generator.loaded_models[model_name]['tokenizer']\n",
"\n",
" # Prepare messages\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": user_prompt}\n",
" ]\n",
"\n",
" # Tokenize\n",
" inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
" # Generate\n",
" with torch.no_grad():\n",
" outputs = model.generate(\n",
" inputs,\n",
" max_new_tokens=2000,\n",
" temperature=temperature,\n",
" do_sample=True,\n",
" pad_token_id=tokenizer.eos_token_id\n",
" )\n",
"\n",
" # Decode response\n",
" response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
" # Extract only the assistant's response\n",
" if \"<|assistant|>\" in response:\n",
" response = response.split(\"<|assistant|>\")[-1].strip()\n",
" elif \"assistant\" in response:\n",
" response = response.split(\"assistant\")[-1].strip()\n",
"\n",
" return response\n",
"\n",
" except Exception as e:\n",
" # Clean up on error\n",
" if model_name in dataset_generator.loaded_models:\n",
" del dataset_generator.loaded_models[model_name]\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
" raise Exception(f\"HuggingFace schema generation error: {str(e)}\")\n",
"\n",
"print(\"✅ Fixed HuggingFace schema generation function created!\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fd4b8c8",
"metadata": {},
"outputs": [],
"source": [
"# Fixed File Download for Google Colab\n",
"import io\n",
"from google.colab import files\n",
"\n",
"def save_dataset_colab(records: List[Dict], file_format: str, filename: str) -> str:\n",
" \"\"\"Save dataset and trigger download in Google Colab\"\"\"\n",
" if not records:\n",
" return \"❌ Error: No data to export\"\n",
"\n",
" try:\n",
" # Ensure filename has correct extension\n",
" if not filename.endswith(file_format):\n",
" filename += file_format\n",
"\n",
" # Create DataFrame\n",
" df = pd.DataFrame(records)\n",
"\n",
" if file_format == \".csv\":\n",
" csv_buffer = io.StringIO()\n",
" df.to_csv(csv_buffer, index=False)\n",
" csv_data = csv_buffer.getvalue()\n",
" files.download(io.BytesIO(csv_data.encode()), filename)\n",
" \n",
" elif file_format == \".tsv\":\n",
" tsv_buffer = io.StringIO()\n",
" df.to_csv(tsv_buffer, sep=\"\\t\", index=False)\n",
" tsv_data = tsv_buffer.getvalue()\n",
" files.download(io.BytesIO(tsv_data.encode()), filename)\n",
" \n",
" elif file_format == \".json\":\n",
" json_data = df.to_json(orient=\"records\", indent=2)\n",
" files.download(io.BytesIO(json_data.encode()), filename)\n",
" \n",
" elif file_format == \".jsonl\":\n",
" jsonl_data = \"\\n\".join([json.dumps(record) for record in records])\n",
" files.download(io.BytesIO(jsonl_data.encode()), filename)\n",
" else:\n",
" return f\"❌ Error: Unsupported format {file_format}\"\n",
"\n",
" return f\"✅ Dataset downloaded as {filename} ({len(records)} records)\"\n",
"\n",
" except Exception as e:\n",
" return f\"❌ Error saving dataset: {str(e)}\"\n",
"\n",
"def save_with_scores_colab(records: List[Dict], scores: List[int], file_format: str, filename: str) -> str:\n",
" \"\"\"Save dataset with quality scores and trigger download in Google Colab\"\"\"\n",
" if not records or not scores:\n",
" return \"❌ Error: No data or scores to export\"\n",
"\n",
" try:\n",
" # Add scores to records\n",
" records_with_scores = []\n",
" for i, record in enumerate(records):\n",
" record_with_score = record.copy()\n",
" record_with_score['quality_score'] = scores[i] if i < len(scores) else 0\n",
" records_with_scores.append(record_with_score)\n",
"\n",
" return save_dataset_colab(records_with_scores, file_format, filename)\n",
"\n",
" except Exception as e:\n",
" return f\"❌ Error saving dataset with scores: {str(e)}\"\n",
"\n",
"print(\"✅ Fixed file download functions for Google Colab created!\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e94ff70",
"metadata": {},
"outputs": [],
"source": [
"# Fixed UI Functions with Schema Flow and File Download\n",
"def generate_schema_fixed(business_case, schema_mode, schema_text, model_name, temperature):\n",
" \"\"\"Generate or enhance schema based on mode - FIXED VERSION\"\"\"\n",
" global current_schema_text\n",
" \n",
" if schema_mode == \"LLM Generate\":\n",
" if model_name in HUGGINGFACE_MODELS:\n",
" result = fixed_schema_generation_huggingface(\n",
" business_case, \n",
" \"You are an expert data scientist. Given a business case, generate a comprehensive dataset schema. Return the schema in this exact format: field_name (TYPE) - Description, example: example_value. Include 8-12 relevant fields that would be useful for the business case. Use realistic field names and appropriate data types (TEXT, INT, FLOAT, BOOLEAN, ARRAY). Provide clear descriptions and realistic examples.\",\n",
" f\"Business case: {business_case}\\n\\nGenerate a dataset schema for this business case. Include fields that would be relevant for analysis and decision-making.\",\n",
" temperature\n",
" )\n",
" else:\n",
" result = schema_manager.generate_schema_with_llm(business_case, model_name, temperature)\n",
" current_schema_text = result\n",
" return result, result\n",
" elif schema_mode == \"LLM Enhance Manual\":\n",
" if model_name in HUGGINGFACE_MODELS:\n",
" result = fixed_schema_generation_huggingface(\n",
" business_case,\n",
" \"You are an expert data scientist. Given a partial schema and business case, enhance it by: 1. Adding missing relevant fields 2. Improving field descriptions 3. Adding realistic examples 4. Ensuring proper data types. Return the enhanced schema in the same format as the original.\",\n",
" f\"Business case: {business_case}\\n\\nCurrent partial schema:\\n{schema_text}\\n\\nPlease enhance this schema by adding missing fields and improving the existing ones.\",\n",
" temperature\n",
" )\n",
" else:\n",
" result = schema_manager.enhance_schema_with_llm(schema_text, business_case, model_name, temperature)\n",
" current_schema_text = result\n",
" return result, result\n",
" else: # Manual Entry\n",
" current_schema_text = schema_text\n",
" return schema_text, schema_text\n",
"\n",
"def export_dataset_fixed(file_format, filename, include_scores):\n",
" \"\"\"Export dataset to specified format - FIXED VERSION for Google Colab\"\"\"\n",
" global current_dataset, current_scores\n",
"\n",
" if not current_dataset:\n",
" return \"No dataset to export\"\n",
"\n",
" try:\n",
" if include_scores and current_scores:\n",
" result = save_with_scores_colab(current_dataset, current_scores, file_format, filename)\n",
" else:\n",
" result = save_dataset_colab(current_dataset, file_format, filename)\n",
" return result\n",
" except Exception as e:\n",
" return f\"❌ Error exporting dataset: {str(e)}\"\n",
"\n",
"print(\"✅ Fixed UI functions with schema flow and file download created!\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8e47887",
"metadata": {},
"outputs": [],
"source": [
"# Updated Gradio Interface with Fixed Functions\n",
"def create_fixed_gradio_interface():\n",
" \"\"\"Create the main Gradio interface with 5 tabs - FIXED VERSION\"\"\"\n",
"\n",
" # Combine all models for dropdowns\n",
" all_models = list(HUGGINGFACE_MODELS.keys()) + list(COMMERCIAL_MODELS.keys())\n",
"\n",
" with gr.Blocks(title=\"Synthetic Dataset Generator\", theme=gr.themes.Soft()) as interface:\n",
"\n",
" gr.Markdown(\"# 🧬 Synthetic Dataset Generator with Quality Scoring\")\n",
" gr.Markdown(\"Generate realistic synthetic datasets using multiple LLM models with flexible schema creation, synonym permutation, and automated quality scoring.\")\n",
"\n",
" # Status bar\n",
" with gr.Row():\n",
" gpu_status = gr.Textbox(\n",
" label=\"GPU Status\",\n",
" value=dataset_generator.get_memory_usage(),\n",
" interactive=False,\n",
" scale=1\n",
" )\n",
" current_status = gr.Textbox(\n",
" label=\"Current Status\",\n",
" value=\"Ready to generate datasets\",\n",
" interactive=False,\n",
" scale=2\n",
" )\n",
"\n",
" # Tab 1: Schema Definition\n",
" with gr.Tab(\"📋 Schema Definition\"):\n",
" gr.Markdown(\"### Define your dataset schema\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" schema_mode = gr.Radio(\n",
" choices=[\"LLM Generate\", \"Manual Entry\", \"LLM Enhance Manual\"],\n",
" value=\"Manual Entry\",\n",
" label=\"Schema Mode\"\n",
" )\n",
"\n",
" business_case_input = gr.Textbox(\n",
" label=\"Business Case\",\n",
" value=current_business_case,\n",
" lines=3,\n",
" placeholder=\"Describe your business case or data requirements...\"\n",
" )\n",
"\n",
" schema_input = gr.Textbox(\n",
" label=\"Schema Definition\",\n",
" value=current_schema_text,\n",
" lines=15,\n",
" placeholder=\"Define your dataset schema here...\"\n",
" )\n",
"\n",
" with gr.Row():\n",
" schema_model = gr.Dropdown(\n",
" choices=all_models,\n",
" value=all_models[0],\n",
" label=\"Model for Schema Generation\"\n",
" )\n",
" schema_temperature = gr.Slider(\n",
" minimum=0.0,\n",
" maximum=2.0,\n",
" value=0.7,\n",
" step=0.1,\n",
" label=\"Temperature\"\n",
" )\n",
"\n",
" generate_schema_btn = gr.Button(\"🔄 Generate/Enhance Schema\", variant=\"primary\")\n",
"\n",
" with gr.Column(scale=1):\n",
" schema_output = gr.Textbox(\n",
" label=\"Generated Schema\",\n",
" lines=15,\n",
" interactive=False\n",
" )\n",
"\n",
" # Tab 2: Dataset Generation\n",
" with gr.Tab(\"🚀 Dataset Generation\"):\n",
" gr.Markdown(\"### Generate synthetic dataset\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" generation_schema = gr.Textbox(\n",
" label=\"Schema (from Tab 1)\",\n",
" value=current_schema_text,\n",
" lines=8,\n",
" interactive=False\n",
" )\n",
"\n",
" generation_business_case = gr.Textbox(\n",
" label=\"Business Case\",\n",
" value=current_business_case,\n",
" lines=2\n",
" )\n",
"\n",
" examples_input = gr.Textbox(\n",
" label=\"Few-shot Examples (JSON format)\",\n",
" lines=5,\n",
" placeholder='[{\"instruction\": \"example\", \"response\": \"example\"}]',\n",
" value=\"\"\n",
" )\n",
"\n",
" with gr.Row():\n",
" generation_model = gr.Dropdown(\n",
" choices=all_models,\n",
" value=all_models[0],\n",
" label=\"Generation Model\"\n",
" )\n",
" generation_temperature = gr.Slider(\n",
" minimum=0.0,\n",
" maximum=2.0,\n",
" value=0.7,\n",
" step=0.1,\n",
" label=\"Temperature\"\n",
" )\n",
" num_records = gr.Number(\n",
" value=50,\n",
" minimum=11,\n",
" maximum=1000,\n",
" step=1,\n",
" label=\"Number of Records\"\n",
" )\n",
"\n",
" generate_dataset_btn = gr.Button(\"🚀 Generate Dataset\", variant=\"primary\", size=\"lg\")\n",
"\n",
" with gr.Column(scale=1):\n",
" generation_status = gr.Textbox(\n",
" label=\"Generation Status\",\n",
" lines=3,\n",
" interactive=False\n",
" )\n",
"\n",
" dataset_preview = gr.Dataframe(\n",
" label=\"Dataset Preview (First 20 rows)\",\n",
" interactive=False,\n",
" wrap=True\n",
" )\n",
"\n",
" record_count = gr.Number(\n",
" label=\"Total Records Generated\",\n",
" interactive=False\n",
" )\n",
"\n",
" # Tab 3: Synonym Permutation\n",
" with gr.Tab(\"🔄 Synonym Permutation\"):\n",
" gr.Markdown(\"### Add diversity with synonym replacement\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" enable_permutation = gr.Checkbox(\n",
" label=\"Enable Synonym Permutation\",\n",
" value=False\n",
" )\n",
"\n",
" fields_to_permute = gr.CheckboxGroup(\n",
" label=\"Fields to Permute\",\n",
" choices=[],\n",
" value=[]\n",
" )\n",
"\n",
" permutation_rate = gr.Slider(\n",
" minimum=0,\n",
" maximum=50,\n",
" value=20,\n",
" step=5,\n",
" label=\"Permutation Rate (%)\"\n",
" )\n",
"\n",
" apply_permutation_btn = gr.Button(\"🔄 Apply Permutation\", variant=\"secondary\")\n",
"\n",
" with gr.Column(scale=1):\n",
" permutation_status = gr.Textbox(\n",
" label=\"Permutation Status\",\n",
" lines=2,\n",
" interactive=False\n",
" )\n",
"\n",
" permuted_preview = gr.Dataframe(\n",
" label=\"Permuted Dataset Preview\",\n",
" interactive=False,\n",
" wrap=True\n",
" )\n",
"\n",
" # Tab 4: Quality Scoring\n",
" with gr.Tab(\"📊 Quality Scoring\"):\n",
" gr.Markdown(\"### Evaluate dataset quality\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" scoring_model = gr.Dropdown(\n",
" choices=all_models,\n",
" value=all_models[0],\n",
" label=\"Scoring Model\"\n",
" )\n",
"\n",
" scoring_temperature = gr.Slider(\n",
" minimum=0.0,\n",
" maximum=2.0,\n",
" value=0.3,\n",
" step=0.1,\n",
" label=\"Temperature\"\n",
" )\n",
"\n",
" score_dataset_btn = gr.Button(\"📊 Score Dataset Quality\", variant=\"primary\")\n",
"\n",
" with gr.Column(scale=1):\n",
" scoring_status = gr.Textbox(\n",
" label=\"Scoring Status\",\n",
" lines=2,\n",
" interactive=False\n",
" )\n",
"\n",
" scores_dataframe = gr.Dataframe(\n",
" label=\"Quality Scores\",\n",
" interactive=False\n",
" )\n",
"\n",
" quality_report = gr.JSON(\n",
" label=\"Quality Report\"\n",
" )\n",
"\n",
" # Tab 5: Export\n",
" with gr.Tab(\"💾 Export\"):\n",
" gr.Markdown(\"### Export your dataset\")\n",
"\n",
" with gr.Row():\n",
" with gr.Column(scale=2):\n",
" file_format = gr.Dropdown(\n",
" choices=OUTPUT_FORMATS,\n",
" value=\".csv\",\n",
" label=\"File Format\"\n",
" )\n",
"\n",
" filename = gr.Textbox(\n",
" label=\"Filename\",\n",
" value=\"synthetic_dataset\",\n",
" placeholder=\"Enter filename (extension added automatically)\"\n",
" )\n",
"\n",
" include_scores = gr.Checkbox(\n",
" label=\"Include Quality Scores\",\n",
" value=False\n",
" )\n",
"\n",
" export_btn = gr.Button(\"💾 Export Dataset\", variant=\"primary\")\n",
"\n",
" with gr.Column(scale=1):\n",
" export_status = gr.Textbox(\n",
" label=\"Export Status\",\n",
" lines=3,\n",
" interactive=False\n",
" )\n",
"\n",
" # Event handlers - FIXED VERSION\n",
" generate_schema_btn.click(\n",
" generate_schema_fixed,\n",
" inputs=[business_case_input, schema_mode, schema_input, schema_model, schema_temperature],\n",
" outputs=[schema_output, schema_input, generation_schema]\n",
" )\n",
"\n",
" generate_dataset_btn.click(\n",
" generate_dataset_ui,\n",
" inputs=[generation_schema, generation_business_case, generation_model, generation_temperature, num_records, examples_input],\n",
" outputs=[generation_status, dataset_preview, record_count]\n",
" )\n",
"\n",
" apply_permutation_btn.click(\n",
" apply_synonym_permutation,\n",
" inputs=[enable_permutation, fields_to_permute, permutation_rate],\n",
" outputs=[permuted_preview, permutation_status]\n",
" )\n",
"\n",
" score_dataset_btn.click(\n",
" score_dataset_quality,\n",
" inputs=[scoring_model, scoring_temperature],\n",
" outputs=[scoring_status, scores_dataframe, quality_report]\n",
" )\n",
"\n",
" export_btn.click(\n",
" export_dataset_fixed,\n",
" inputs=[file_format, filename, include_scores],\n",
" outputs=[export_status]\n",
" )\n",
"\n",
" # Update field choices when dataset is generated\n",
" def update_field_choices():\n",
" fields = get_available_fields()\n",
" return gr.CheckboxGroup(choices=fields, value=[])\n",
"\n",
" # Auto-update field choices\n",
" generate_dataset_btn.click(\n",
" update_field_choices,\n",
" outputs=[fields_to_permute]\n",
" )\n",
"\n",
" return interface\n",
"\n",
"print(\"✅ Fixed Gradio Interface created!\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "469152ce",
"metadata": {},
"outputs": [],
"source": [
"# Launch the Fixed Gradio Interface\n",
"print(\"🚀 Launching Fixed Synthetic Dataset Generator...\")\n",
"interface = create_fixed_gradio_interface()\n",
"interface.launch(debug=True, share=True)\n"
]
},
{
"cell_type": "markdown",
"id": "42127e24",
"metadata": {},
"source": [
"## 🔧 Bug Fixes Applied\n",
"\n",
"### ✅ Issues Fixed:\n",
"\n",
"1. **Schema Flow Issue**: \n",
" - Fixed schema generation to properly pass generated schema to Dataset Generation tab\n",
" - Updated `generate_schema_fixed()` function to update global `current_schema_text`\n",
" - Added proper output connections in Gradio interface\n",
"\n",
"2. **File Download Issue**:\n",
" - Implemented Google Colab-compatible file download using `google.colab.files.download()`\n",
" - Created `save_dataset_colab()` and `save_with_scores_colab()` functions\n",
" - Files now download directly to browser instead of saving to local storage\n",
"\n",
"3. **HuggingFace Schema Generation**:\n",
" - Implemented `fixed_schema_generation_huggingface()` function\n",
" - Added proper model loading and inference for schema generation\n",
" - Integrated with existing schema management system\n",
"\n",
"4. **HuggingFace Model Import Issues**:\n",
" - Added correct model classes for each HuggingFace model:\n",
" - Llama models: `LlamaForCausalLM`\n",
" - Phi-3.5: `Phi3ForCausalLM`\n",
" - Gemma 2: `GemmaForCausalLM`\n",
" - Qwen 2.5: `Qwen2ForCausalLM`\n",
" - Mistral models: `MistralForCausalLM`\n",
" - Created `load_huggingface_model_with_correct_class()` function with fallback to `AutoModelForCausalLM`\n",
" - Updated model configuration with `model_class` field\n",
"\n",
"5. **Updated Dependencies**:\n",
" - Added `google-colab` package for proper Colab integration\n",
" - Fixed import issues for Google Colab environment\n",
"\n",
"### 🚀 How to Use the Fixed Version:\n",
"\n",
"1. **Run all cells in order** - the fixes are applied automatically\n",
"2. **Schema Tab**: Generate schema with any model (HuggingFace or Commercial)\n",
"3. **Dataset Tab**: Schema automatically flows from Tab 1\n",
"4. **Export Tab**: Files download directly to your browser\n",
"5. **All HuggingFace models** now work properly for both schema generation and dataset generation\n",
"\n",
"### 🔧 Technical Details:\n",
"\n",
"- **Model Loading**: Uses correct model classes with fallback to AutoModelForCausalLM\n",
"- **File Downloads**: Uses `google.colab.files.download()` for browser downloads\n",
"- **Schema Flow**: Global variables ensure schema passes between tabs\n",
"- **Error Handling**: Comprehensive error handling with model cleanup\n",
"- **Memory Management**: Proper GPU memory cleanup on errors\n",
"\n",
"The application should now work seamlessly in Google Colab with all features functional!\n"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -1870,6 +2521,11 @@
} }
], ],
"metadata": { "metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",
"name": "python3" "name": "python3"
@@ -1885,13 +2541,8 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.10" "version": "3.11.10"
}, }
"colab": {
"provenance": [],
"gpuType": "T4"
},
"accelerator": "GPU"
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }