Merge pull request #799 from TheTopDeveloper/community-contributions-branch
Add Week 3 Synthetic Survey Dataset Generator with improved LLM generation - Joshua Oluoch (Andela GenAI Bootcamp)
@@ -0,0 +1,906 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8dbb4e8",
   "metadata": {},
   "source": [
    "# 🧪 Survey Synthetic Dataset Generator — Week 3 Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d86f629",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os, re, json, time, uuid, math, random\n",
    "from datetime import datetime, timedelta\n",
    "from typing import List, Dict, Any\n",
    "import numpy as np, pandas as pd\n",
    "try:\n",
    "    import pandera.pandas as pa  # optional schema-validation engine\n",
    "except ImportError:\n",
    "    pa = None  # validation falls back to basic checks below\n",
    "random.seed(7); np.random.seed(7)\n",
    "print(\"✅ Base libraries ready. Pandera available:\", pa is not None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f196ae73",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def extract_strict_json(text: str):\n",
    "    \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n",
    "    if text is None:\n",
    "        raise ValueError(\"Empty model output.\")\n",
    "\n",
    "    t = text.strip()\n",
    "\n",
    "    # Strategy 1: Direct JSON parsing\n",
    "    try:\n",
    "        obj = json.loads(t)\n",
    "        if isinstance(obj, list):\n",
    "            return obj\n",
    "        elif isinstance(obj, dict):\n",
    "            for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n",
    "                if key in obj and isinstance(obj[key], list):\n",
    "                    return obj[key]\n",
    "            if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n",
    "                return [obj[k] for k in sorted(obj.keys(), key=int)]\n",
    "    except json.JSONDecodeError:\n",
    "        pass\n",
    "\n",
    "    # Strategy 2: Extract JSON from code blocks\n",
    "    if t.startswith(\"```\"):\n",
    "        t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n",
    "\n",
    "    # Strategy 3: Find JSON array in text\n",
    "    start, end = t.find('['), t.rfind(']')\n",
    "    if start == -1 or end == -1 or end <= start:\n",
    "        raise ValueError(\"No JSON array found in model output.\")\n",
    "\n",
    "    t = t[start:end+1]\n",
    "\n",
    "    # Strategy 4: Fix common JSON issues\n",
    "    t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t)  # Remove trailing commas\n",
    "    t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t)  # Replace NaN/Infinity\n",
    "    t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\")  # Remove invisible characters\n",
    "\n",
    "    try:\n",
    "        return json.loads(t)\n",
    "    except json.JSONDecodeError as e:\n",
    "        raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n"
   ]
},
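  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-extract-strict-json",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check (added example, not part of the original task):\n",
    "# exercise extract_strict_json on a fenced, trailing-comma payload of the\n",
    "# kind an LLM often returns, to confirm the fallback strategies fire.\n",
    "_messy = '```json\\n[{\"a\": 1}, {\"a\": 2},]\\n```'\n",
    "print(extract_strict_json(_messy))  # expected: [{'a': 1}, {'a': 2}]\n"
   ]
  },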
  {
   "cell_type": "markdown",
   "id": "3670fa0d",
   "metadata": {},
   "source": [
    "## 1) Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d16bd03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "CFG = {\n",
    "    \"rows\": 800,\n",
    "    \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n",
    "    \"fields\": [\n",
    "        {\"name\": \"response_id\", \"type\": \"uuid4\"},\n",
    "        {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n",
    "        {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n",
    "        {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n",
    "        {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n",
    "        {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n",
    "        {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n",
    "        {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n",
    "        {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n",
    "        {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n",
    "        {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n",
    "        {\"name\": \"attention_passed\", \"type\": \"bool\"},\n",
    "        {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
    "        {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n",
    "        {\"name\": \"is_detractor\", \"type\": \"bool\"}\n",
    "    ]\n",
    "}\n",
    "print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n"
   ]
},
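  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-config-sanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check (added example): np.random.choice requires enum\n",
    "# probabilities to sum to 1, so verify that and the min/max ordering\n",
    "# before any generation runs.\n",
    "for _f in CFG[\"fields\"]:\n",
    "    if _f[\"type\"] == \"enum\" and \"probs\" in _f:\n",
    "        assert abs(sum(_f[\"probs\"]) - 1.0) < 1e-9, f\"probs of {_f['name']} do not sum to 1\"\n",
    "    if \"min\" in _f and \"max\" in _f:\n",
    "        assert _f[\"min\"] <= _f[\"max\"], f\"min > max for {_f['name']}\"\n",
    "print(\"Config sanity checks passed.\")\n"
   ]
  },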
  {
   "cell_type": "markdown",
   "id": "7da1f429",
   "metadata": {},
   "source": [
    "## 2) Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f5fdff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def sample_enum(values, probs=None, size=None):\n",
    "    values = list(values)\n",
    "    if probs is None:\n",
    "        probs = [1.0 / len(values)] * len(values)\n",
    "    return np.random.choice(values, p=probs, size=size)\n",
    "\n",
    "def sample_numeric(field_cfg, size=1):\n",
    "    t = field_cfg[\"type\"]\n",
    "    if t == \"int\":\n",
    "        lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n",
    "        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
    "        if dist == \"uniform\":\n",
    "            return np.random.randint(lo, hi + 1, size=size)\n",
    "        elif dist == \"normal\":\n",
    "            mu = (lo + hi) / 2.0\n",
    "            sigma = (hi - lo) / 6.0\n",
    "            out = np.random.normal(mu, sigma, size=size)\n",
    "            return np.clip(out, lo, hi).astype(int)\n",
    "        else:\n",
    "            return np.random.randint(lo, hi + 1, size=size)\n",
    "    elif t == \"float\":\n",
    "        lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n",
    "        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
    "        if dist == \"uniform\":\n",
    "            return np.random.uniform(lo, hi, size=size)\n",
    "        elif dist == \"normal\":\n",
    "            mu = (lo + hi) / 2.0\n",
    "            sigma = (hi - lo) / 6.0\n",
    "            return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n",
    "        elif dist == \"lognormal\":\n",
    "            mu = math.log(max(1e-3, (lo + hi) / 2.0))\n",
    "            sigma = 0.75\n",
    "            out = np.random.lognormal(mu, sigma, size=size)\n",
    "            return np.clip(out, lo, hi)\n",
    "        else:\n",
    "            return np.random.uniform(lo, hi, size=size)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported numeric type\")\n",
    "\n",
    "def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n",
    "    s = datetime.fromisoformat(start)\n",
    "    e = datetime.fromisoformat(end)\n",
    "    total = int((e - s).total_seconds())\n",
    "    r = np.random.randint(0, total, size=size)\n",
    "    return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n"
   ]
},
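  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged smoke test (added example): draw a few values from each helper\n",
    "# to eyeball ranges and formats before wiring them into the generator.\n",
    "print(\"enum: \", sample_enum([\"android\",\"ios\",\"web\"], [0.6,0.25,0.15], size=5))\n",
    "print(\"int:  \", sample_numeric({\"type\":\"int\",\"min\":1,\"max\":5}, size=5))\n",
    "print(\"float:\", np.round(sample_numeric({\"type\":\"float\",\"min\":60,\"max\":1800,\"distribution\":\"lognormal\"}, size=3), 1))\n",
    "print(\"dt:   \", sample_datetime(\"2024-01-01\", \"2025-10-01\", size=2))\n"
   ]
  },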
  {
   "cell_type": "markdown",
   "id": "5f24111a",
   "metadata": {},
   "source": [
    "## 3) Rule-based Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd61330d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n",
    "    n = CFG[\"rows\"]\n",
    "    dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n",
    "    data = {}\n",
    "    for f in CFG[\"fields\"]:\n",
    "        name, t = f[\"name\"], f[\"type\"]\n",
    "        if t == \"uuid4\":\n",
    "            data[name] = [str(uuid.uuid4()) for _ in range(n)]\n",
    "        elif t in (\"int\",\"float\"):\n",
    "            data[name] = sample_numeric(f, size=n)\n",
    "        elif t == \"enum\":\n",
    "            data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n",
    "        elif t == \"datetime\":\n",
    "            data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n",
    "        elif t == \"bool\":\n",
    "            data[name] = np.random.rand(n) < 0.9  # 90% True\n",
    "        else:\n",
    "            data[name] = [None]*n\n",
    "    df = pd.DataFrame(data)\n",
    "\n",
    "    # Derive NPS roughly from likert questions\n",
    "    if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n",
    "        likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n",
    "        df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)\n",
    "\n",
    "    # Heuristic target: is_detractor more likely when completion high & attention failed\n",
    "    if \"is_detractor\" in df.columns:\n",
    "        base = 0.25\n",
    "        comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n",
    "        attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n",
    "        boost = (comp > 900).astype(int) + (~attn).astype(int)\n",
    "        p = np.clip(base + 0.15*boost, 0.01, 0.95)\n",
    "        df[\"is_detractor\"] = np.random.rand(n) < p\n",
    "\n",
    "    return df\n",
    "\n",
    "df_rule = generate_rule_based(CFG)\n",
    "df_rule.head()\n"
   ]
},
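  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-rule-sanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check (added example): confirm the derived columns respect\n",
    "# their intended ranges and report the heuristic detractor rate.\n",
    "assert df_rule[\"nps\"].between(0, 10).all(), \"nps out of range\"\n",
    "assert df_rule[\"age\"].between(18, 70).all(), \"age out of range\"\n",
    "print(\"rows:\", len(df_rule))\n",
    "print(\"detractor rate: {:.1%}\".format(df_rule[\"is_detractor\"].mean()))\n"
   ]
  },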
  {
   "cell_type": "markdown",
   "id": "dd9eff20",
   "metadata": {},
   "source": [
    "## 4) Validation (Pandera optional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4ef86a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_pandera_schema(CFG):\n",
    "    if pa is None:\n",
    "        return None\n",
    "    cols = {}\n",
    "    for f in CFG[\"fields\"]:\n",
    "        t, name = f[\"type\"], f[\"name\"]\n",
    "        if t == \"int\": cols[name] = pa.Column(int)\n",
    "        elif t == \"float\": cols[name] = pa.Column(float)\n",
    "        elif t == \"enum\": cols[name] = pa.Column(object)\n",
    "        elif t == \"datetime\": cols[name] = pa.Column(object)\n",
    "        elif t == \"uuid4\": cols[name] = pa.Column(object)\n",
    "        elif t == \"bool\": cols[name] = pa.Column(bool)\n",
    "        else: cols[name] = pa.Column(object)\n",
    "    return pa.DataFrameSchema(cols)\n",
    "\n",
    "def validate_df(df, CFG):\n",
    "    schema = build_pandera_schema(CFG)\n",
    "    if schema is None:\n",
    "        return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n",
    "    try:\n",
    "        v = schema.validate(df, lazy=True)\n",
    "        return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n",
    "    except Exception as e:\n",
    "        print(\"Validation error:\", e)\n",
    "        return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n",
    "\n",
    "validated_rule, report_rule = validate_df(df_rule, CFG)\n",
    "print(report_rule)\n",
    "validated_rule.head()\n"
   ]
},
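  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-strict-schema",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch (added example): a stricter schema than build_pandera_schema,\n",
    "# with range/enum checks on a few columns. Assumes pa.Check is exposed by the\n",
    "# installed pandera version; skipped entirely when Pandera is absent.\n",
    "if pa is not None:\n",
    "    strict = pa.DataFrameSchema({\n",
    "        \"age\": pa.Column(int, pa.Check.in_range(18, 70)),\n",
    "        \"nps\": pa.Column(int, pa.Check.in_range(0, 10)),\n",
    "        \"country\": pa.Column(object, pa.Check.isin([\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"])),\n",
    "    })\n",
    "    strict.validate(df_rule[[\"age\",\"nps\",\"country\"]], lazy=True)\n",
    "    print(\"Strict range/enum checks passed.\")\n"
   ]
  },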
  {
   "cell_type": "markdown",
   "id": "d5f1d93a",
   "metadata": {},
   "source": [
    "## 5) Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73626b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from pathlib import Path\n",
    "out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "csv_path = out / f\"survey_rule_{ts}.csv\"\n",
    "validated_rule.to_csv(csv_path, index=False)\n",
    "print(\"Saved:\", csv_path.as_posix())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87c89b51",
   "metadata": {},
   "source": [
    "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e94771",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed LLM Generation Functions\n",
    "def create_survey_prompt(CFG, n_rows=50):\n",
    "    \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n",
    "    fields_desc = []\n",
    "    for field in CFG['fields']:\n",
    "        name = field['name']\n",
    "        field_type = field['type']\n",
    "\n",
    "        if field_type == 'int':\n",
    "            min_val = field.get('min', 0)\n",
    "            max_val = field.get('max', 100)\n",
    "            fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n",
    "        elif field_type == 'float':\n",
    "            min_val = field.get('min', 0.0)\n",
    "            max_val = field.get('max', 100.0)\n",
    "            fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n",
    "        elif field_type == 'enum':\n",
    "            values = field.get('values', [])\n",
    "            fields_desc.append(f\" - {name}: one of {values}\")\n",
    "        elif field_type == 'bool':\n",
    "            fields_desc.append(f\" - {name}: boolean (true/false)\")\n",
    "        elif field_type == 'uuid4':\n",
    "            fields_desc.append(f\" - {name}: UUID string\")\n",
    "        elif field_type == 'datetime':\n",
    "            fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n",
    "            fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n",
    "        else:\n",
    "            fields_desc.append(f\" - {name}: {field_type}\")\n",
    "\n",
    "    prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n",
    "\n",
    "Schema:\n",
    "{chr(10).join(fields_desc)}\n",
    "\n",
    "CRITICAL REQUIREMENTS:\n",
    "- Return a JSON object with a \"responses\" key containing an array\n",
    "- Each object in the array must have all required fields\n",
    "- Use realistic, diverse values for survey responses\n",
    "- No trailing commas\n",
    "- No comments or explanations\n",
    "\n",
    "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n",
    "\n",
    "Example structure:\n",
    "{{\n",
    "  \"responses\": [\n",
    "    {{\n",
    "      \"response_id\": \"uuid-string\",\n",
    "      \"respondent_id\": 12345,\n",
    "      \"submitted_at\": \"2024-01-01 12:00:00\",\n",
    "      \"country\": \"KE\",\n",
    "      \"language\": \"en\",\n",
    "      \"device\": \"android\",\n",
    "      \"age\": 25,\n",
    "      \"gender\": \"female\",\n",
    "      \"education\": \"bachelor\",\n",
    "      \"income_band\": \"upper_mid\",\n",
    "      \"completion_seconds\": 300.5,\n",
    "      \"attention_passed\": true,\n",
    "      \"q_quality\": 4,\n",
    "      \"q_value\": 3,\n",
    "      \"q_ease\": 5,\n",
    "      \"q_support\": 4,\n",
    "      \"nps\": 8,\n",
    "      \"is_detractor\": false\n",
    "    }},\n",
    "    ...\n",
    "  ]\n",
    "}}\n",
    "\n",
    "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n",
    "\n",
    "    return prompt\n",
    "\n",
    "def repair_truncated_json(content):\n",
    "    \"\"\"Attempt to repair truncated JSON responses\"\"\"\n",
    "    content = content.strip()\n",
    "\n",
    "    # If it starts with { but doesn't end with }, try to close it\n",
    "    if content.startswith('{') and not content.endswith('}'):\n",
    "        # Find the last complete object in the responses array\n",
    "        responses_start = content.find('\"responses\": [')\n",
    "        if responses_start != -1:\n",
    "            # Find the last complete object\n",
    "            brace_count = 0\n",
    "            last_complete_pos = -1\n",
    "            in_string = False\n",
    "            escape_next = False\n",
    "\n",
    "            for i, char in enumerate(content[responses_start:], responses_start):\n",
    "                if escape_next:\n",
    "                    escape_next = False\n",
    "                    continue\n",
    "\n",
    "                if char == '\\\\':\n",
    "                    escape_next = True\n",
    "                    continue\n",
    "\n",
    "                if char == '\"' and not escape_next:\n",
    "                    in_string = not in_string\n",
    "                    continue\n",
    "\n",
    "                if not in_string:\n",
    "                    if char == '{':\n",
    "                        brace_count += 1\n",
    "                    elif char == '}':\n",
    "                        brace_count -= 1\n",
    "                        if brace_count == 0:\n",
    "                            last_complete_pos = i\n",
    "                            break\n",
    "\n",
    "            if last_complete_pos != -1:\n",
    "                # Truncate at the last complete object and close the JSON\n",
    "                repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n",
    "                print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n",
    "                return repaired\n",
    "\n",
    "    return content\n",
    "\n",
    "def fixed_llm_generate_batch(CFG, n_rows=50):\n",
    "    \"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n",
    "    if not os.getenv('OPENAI_API_KEY'):\n",
    "        print(\"No OpenAI API key, using rule-based fallback\")\n",
    "        tmp = dict(CFG); tmp['rows'] = n_rows\n",
    "        return generate_rule_based(tmp)\n",
    "\n",
    "    try:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()\n",
    "\n",
    "        prompt = create_survey_prompt(CFG, n_rows)\n",
    "\n",
    "        print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n",
    "\n",
    "        # Calculate appropriate max_tokens based on batch size\n",
    "        # Roughly 200-300 tokens per row, with some buffer\n",
    "        estimated_tokens = n_rows * 300 + 500  # Buffer for JSON structure\n",
    "        max_tokens = min(max(estimated_tokens, 2000), 8000)  # Between 2k-8k tokens\n",
    "\n",
    "        print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model='gpt-4o-mini',\n",
    "            messages=[\n",
    "                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n",
    "                {'role': 'user', 'content': prompt}\n",
    "            ],\n",
    "            temperature=0.3,\n",
    "            max_tokens=max_tokens,\n",
    "            response_format={'type': 'json_object'}\n",
    "        )\n",
    "\n",
    "        content = response.choices[0].message.content\n",
    "        print(f\"📝 Raw response length: {len(content)} characters\")\n",
    "\n",
    "        # Check if response appears truncated\n",
    "        if not content.strip().endswith('}') and not content.strip().endswith(']'):\n",
    "            print(\"⚠️ Response appears truncated, attempting repair...\")\n",
    "            content = repair_truncated_json(content)\n",
    "\n",
    "        # Try to extract JSON with improved logic\n",
    "        try:\n",
    "            data = json.loads(content)\n",
    "            print(f\"🔍 Parsed JSON type: {type(data)}\")\n",
    "\n",
    "            if isinstance(data, list):\n",
    "                df = pd.DataFrame(data)\n",
    "                print(f\"📊 Direct array: {len(df)} rows\")\n",
    "            elif isinstance(data, dict):\n",
    "                # Check for common keys that might contain the data\n",
    "                for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n",
    "                    if key in data and isinstance(data[key], list):\n",
    "                        df = pd.DataFrame(data[key])\n",
    "                        print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
    "                        break\n",
    "                else:\n",
    "                    # If no standard key found, check if all values are lists/objects\n",
    "                    list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n",
    "                    if list_keys:\n",
    "                        # Use the first list key found\n",
    "                        key = list_keys[0]\n",
    "                        df = pd.DataFrame(data[key])\n",
    "                        print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
    "                    else:\n",
    "                        # Try to convert the dict values to a list\n",
    "                        if all(isinstance(v, dict) for v in data.values()):\n",
    "                            df = pd.DataFrame(list(data.values()))\n",
    "                            print(f\"📊 Converted dict values: {len(df)} rows\")\n",
    "                        else:\n",
    "                            raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n",
    "            else:\n",
    "                raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n",
    "\n",
    "            if len(df) == n_rows:\n",
    "                print(f\"✅ Successfully generated {len(df)} survey responses\")\n",
    "                return df\n",
    "            else:\n",
    "                print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}\")\n",
    "                if len(df) > 0:\n",
    "                    return df\n",
    "                else:\n",
    "                    raise ValueError(\"No data generated\")\n",
    "\n",
    "        except json.JSONDecodeError as e:\n",
    "            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
    "            # Try the improved extract_strict_json function\n",
    "            try:\n",
    "                data = extract_strict_json(content)\n",
    "                df = pd.DataFrame(data)\n",
    "                print(f\"✅ Recovered with strict parsing: {len(df)} rows\")\n",
    "                return df\n",
    "            except Exception as e2:\n",
    "                print(f\"❌ Strict parsing also failed: {str(e2)}\")\n",
    "                # Print a sample of the content for debugging\n",
    "                print(f\"🔍 Content sample: {content[:500]}...\")\n",
    "                raise e2\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f'❌ LLM error, fallback to rule-based mock: {str(e)}')\n",
    "        tmp = dict(CFG); tmp['rows'] = n_rows\n",
    "        return generate_rule_based(tmp)\n",
    "\n",
    "def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n",
    "    \"\"\"Fixed LLM generation with adaptive batch processing\"\"\"\n",
    "    print(f\"🚀 Generating {total_rows} survey responses with adaptive batching\")\n",
    "\n",
    "    # Adaptive batch sizing based on total rows\n",
    "    if total_rows <= 20:\n",
    "        optimal_batch_size = min(batch_size, total_rows)\n",
    "    elif total_rows <= 50:\n",
    "        optimal_batch_size = min(15, batch_size)\n",
    "    elif total_rows <= 100:\n",
    "        optimal_batch_size = min(10, batch_size)\n",
    "    else:\n",
    "        optimal_batch_size = min(8, batch_size)\n",
    "\n",
    "    print(f\"📊 Using optimal batch size: {optimal_batch_size}\")\n",
    "\n",
    "    all_dataframes = []\n",
    "    remaining = total_rows\n",
    "\n",
    "    while remaining > 0:\n",
    "        current_batch_size = min(optimal_batch_size, remaining)\n",
    "        print(f\"\\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n",
    "\n",
    "        try:\n",
    "            batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n",
    "            all_dataframes.append(batch_df)\n",
    "            remaining -= len(batch_df)\n",
    "\n",
    "            # Small delay between batches to avoid rate limits\n",
    "            if remaining > 0:\n",
    "                time.sleep(1.5)\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"❌ Batch failed: {str(e)}\")\n",
    "            print(\"🔄 Retrying with smaller batch size...\")\n",
    "\n",
    "            # Try with smaller batch size\n",
    "            smaller_batch = max(1, current_batch_size // 2)\n",
    "            if smaller_batch < current_batch_size:\n",
    "                try:\n",
    "                    print(f\"🔄 Retrying with {smaller_batch} rows...\")\n",
    "                    batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n",
    "                    all_dataframes.append(batch_df)\n",
    "                    remaining -= len(batch_df)\n",
    "                    continue\n",
    "                except Exception as e2:\n",
    "                    print(f\"❌ Retry also failed: {str(e2)}\")\n",
    "\n",
    "            print(f\"Using rule-based fallback for remaining {remaining} rows\")\n",
    "            # generate_rule_based takes a single config dict; pass the row count through it\n",
    "            tmp = dict(CFG); tmp['rows'] = remaining\n",
    "            fallback_df = generate_rule_based(tmp)\n",
    "            all_dataframes.append(fallback_df)\n",
    "            break\n",
    "\n",
    "    if all_dataframes:\n",
    "        result = pd.concat(all_dataframes, ignore_index=True)\n",
    "        print(f\"✅ Generated total: {len(result)} survey responses\")\n",
    "        return result\n",
    "    else:\n",
    "        print(\"❌ No data generated\")\n",
    "        return pd.DataFrame()\n",
    "\n"
   ]
},
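  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-repair-json",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged check (added example): feed repair_truncated_json a response cut\n",
    "# off mid-object and confirm the repaired text parses, keeping only the\n",
    "# complete leading objects.\n",
    "_cut = '{\"responses\": [{\"age\": 25, \"nps\": 8}, {\"age\": 31, \"nps\"'\n",
    "_fixed = repair_truncated_json(_cut)\n",
    "print(json.loads(_fixed))  # expected: {'responses': [{'age': 25, 'nps': 8}]}\n"
   ]
  },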
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1af410e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the fixed LLM generation\n",
    "print(\"🧪 Testing LLM generation...\")\n",
    "\n",
    "# Test with small dataset first\n",
    "test_df = fixed_llm_generate_batch(CFG, 10)\n",
    "print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
    "print(\"\\n📋 First few rows:\")\n",
    "print(test_df.head())\n",
    "print(\"\\n📈 Data types:\")\n",
    "print(test_df.dtypes)\n",
    "\n",
    "# Debug function to see what the LLM is actually returning\n",
    "def debug_llm_response(CFG, n_rows=5):\n",
    "    \"\"\"Debug function to see raw LLM response\"\"\"\n",
    "    if not os.getenv('OPENAI_API_KEY'):\n",
    "        print(\"No OpenAI API key available for debugging\")\n",
    "        return\n",
    "\n",
    "    try:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()\n",
    "\n",
    "        prompt = create_survey_prompt(CFG, n_rows)\n",
    "\n",
    "        print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n",
    "        print(f\"📝 Prompt length: {len(prompt)} characters\")\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model='gpt-4o-mini',\n",
    "            messages=[\n",
    "                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n",
    "                {'role': 'user', 'content': prompt}\n",
    "            ],\n",
    "            temperature=0.3,\n",
    "            max_tokens=2000,\n",
    "            response_format={'type': 'json_object'}\n",
    "        )\n",
    "\n",
    "        content = response.choices[0].message.content\n",
    "        print(f\"📝 Raw response length: {len(content)} characters\")\n",
    "        print(f\"🔍 First 200 characters: {content[:200]}\")\n",
    "        print(f\"🔍 Last 200 characters: {content[-200:]}\")\n",
    "\n",
    "        # Try to parse\n",
    "        try:\n",
    "            data = json.loads(content)\n",
    "            print(\"✅ JSON parsed successfully\")\n",
    "            print(f\"🔍 Data type: {type(data)}\")\n",
    "            if isinstance(data, dict):\n",
    "                print(f\"🔍 Dict keys: {list(data.keys())}\")\n",
    "            elif isinstance(data, list):\n",
    "                print(f\"🔍 List length: {len(data)}\")\n",
    "        except Exception as e:\n",
    "            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Debug failed: {str(e)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c90739",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the fixed implementation\n",
    "print(\"🧪 Testing the fixed LLM generation...\")\n",
    "\n",
    "# Test with small dataset\n",
    "test_df = fixed_llm_generate_batch(CFG, 5)\n",
    "print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
    "print(\"\\n📋 First few rows:\")\n",
    "print(test_df.head())\n",
    "print(\"\\n📈 Data types:\")\n",
    "print(test_df.dtypes)\n",
    "\n",
    "if not test_df.empty:\n",
    "    print(\"\\n✅ SUCCESS! LLM generation is now working!\")\n",
    "    print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n",
    "else:\n",
    "    print(\"\\n❌ Still having issues with LLM generation\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd83b842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test larger dataset generation\n",
    "print(\"🚀 Testing larger dataset generation...\")\n",
    "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n",
    "if not large_df.empty:\n",
    "    print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n",
    "    print(\"\\n📈 Summary statistics:\")\n",
    "    print(large_df.describe())\n",
    "\n",
    "    # Save the results\n",
    "    from pathlib import Path\n",
    "    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "    ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "    csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n",
    "    large_df.to_csv(csv_path, index=False)\n",
    "    print(f\"💾 Saved: {csv_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6029d3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def build_json_schema(CFG):\n",
    "    schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n",
    "    props = schema['items']['properties']; req = schema['items']['required']\n",
    "    for f in CFG['fields']:\n",
    "        name, t = f['name'], f['type']\n",
    "        req.append(name)\n",
    "        if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n",
    "        elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n",
    "        elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n",
    "        elif t == 'bool': props[name] = {'type':'boolean'}\n",
    "        else: props[name] = {'type':'string'}\n",
    "    return schema\n",
    "\n",
    "PROMPT_PREAMBLE = (\n",
    "    \"You are a data generator. Return ONLY JSON. \"\n",
    "    \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n",
    "    \"No prose, no code fences, no trailing commas.\"\n",
    ")\n",
    "\n",
    "def render_prompt(CFG, n_rows=100):\n",
    "    minimal_cfg = {'fields': []}\n",
    "    for f in CFG['fields']:\n",
    "        base = {k: f[k] for k in ['name','type'] if k in f}\n",
    "        if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n",
    "        if 'values' in f: base.update({'values': f['values']})\n",
    "        if 'fmt' in f: base.update({'fmt': f['fmt']})\n",
    "        minimal_cfg['fields'].append(base)\n",
    "    return {\n",
    "        'preamble': PROMPT_PREAMBLE,\n",
    "        'n_rows': n_rows,\n",
    "        'schema': build_json_schema(CFG),\n",
    "        'constraints': minimal_cfg,\n",
    "        'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n",
    "    }\n",
    "\n",
    "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n",
    "    try:\n",
    "        obj = json.loads(raw)\n",
    "        if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n",
    "            return pd.DataFrame(obj['rows'])\n",
    "    except Exception:\n",
    "        pass\n",
    "    data = extract_strict_json(raw)\n",
    "    return pd.DataFrame(data)\n",
    "\n",
    "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n",
    "print('LLM available:', USE_LLM)\n",
    "\n",
    "def llm_generate_batch(CFG, n_rows=50):\n",
    "    if USE_LLM:\n",
    "        try:\n",
    "            from openai import OpenAI\n",
    "            client = OpenAI()\n",
    "            prompt = json.dumps(render_prompt(CFG, n_rows))\n",
    "            resp = client.chat.completions.create(\n",
    "                model='gpt-4o-mini',\n",
    "                response_format={'type': 'json_object'},\n",
    "                messages=[\n",
    "                    {'role':'system','content':'You output strict JSON only.'},\n",
    "                    {'role':'user','content': prompt}\n",
    "                ],\n",
    "                temperature=0.2,\n",
    "                max_tokens=8192,\n",
    "            )\n",
    "            raw = resp.choices[0].message.content\n",
    "            try:\n",
    "                return parse_llm_json_to_df(raw)\n",
    "            except Exception:\n",
    "                stricter = (\n",
    "                    prompt\n",
    "                    + \"\\nReturn ONLY a JSON object structured as: \"\n",
    "                    + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n",
    "                    + \"No prose, no explanations.\"\n",
    "                )\n",
    "                resp2 = client.chat.completions.create(\n",
    "                    model='gpt-4o-mini',\n",
    "                    response_format={'type': 'json_object'},\n",
    "                    messages=[\n",
    "                        {'role':'system','content':'You output strict JSON only.'},\n",
    "                        {'role':'user','content': stricter}\n",
    "                    ],\n",
    "                    temperature=0.2,\n",
    "                    max_tokens=8192,\n",
    "                )\n",
    "                raw2 = resp2.choices[0].message.content\n",
    "                return parse_llm_json_to_df(raw2)\n",
    "        except Exception as e:\n",
    "            print('LLM error, fallback to rule-based mock:', e)\n",
    "            tmp = dict(CFG); tmp['rows'] = n_rows\n",
    "            return generate_rule_based(tmp)\n",
    "    # No API key: fall back to the rule-based generator instead of returning None\n",
    "    tmp = dict(CFG); tmp['rows'] = n_rows\n",
    "    return generate_rule_based(tmp)\n",
    "\n",
    "def generate_llm(CFG, total_rows=200, batch_size=50):\n",
    "    dfs = []; remaining = total_rows\n",
    "    while remaining > 0:\n",
    "        b = min(batch_size, remaining)\n",
    "        dfs.append(llm_generate_batch(CFG, n_rows=b))\n",
    "        remaining -= b\n",
    "        time.sleep(0.2)\n",
    "    return pd.concat(dfs, ignore_index=True)\n",
    "\n"
   ]
},
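  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-json-schema",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged peek (added example): inspect the machine-readable contract the\n",
    "# prompt embeds, e.g. the enum constraint emitted for the country field.\n",
    "_schema = build_json_schema(CFG)\n",
    "print(\"required fields:\", len(_schema[\"items\"][\"required\"]))\n",
    "print(\"country spec:\", _schema[\"items\"][\"properties\"][\"country\"])\n",
    "print(\"prompt keys:\", list(render_prompt(CFG, n_rows=5).keys()))\n"
   ]
  },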
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e759087",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n",
    "df_llm.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d4908ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the improved LLM generation with adaptive batching\n",
    "print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n",
    "\n",
    "# Test with smaller dataset first\n",
    "print(\"\\n📦 Testing small batch (10 rows)...\")\n",
    "small_df = fixed_llm_generate_batch(CFG, 10)\n",
    "print(f\"✅ Small batch result: {len(small_df)} rows\")\n",
    "\n",
    "# Test with medium dataset using adaptive batching\n",
    "print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n",
    "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n",
    "print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n",
    "\n",
    "if not medium_df.empty:\n",
    "    print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n",
    "    print(\"\\n📋 First few rows:\")\n",
    "    print(medium_df.head())\n",
    "\n",
    "    # Save the results\n",
    "    from pathlib import Path\n",
    "    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
    "    ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
    "    csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n",
    "    medium_df.to_csv(csv_path, index=False)\n",
    "    print(f\"💾 Saved: {csv_path}\")\n",
    "else:\n",
    "    print(\"❌ Medium dataset generation failed\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}