- Fixed LLM generation issues with adaptive batching - Added JSON repair mechanism for truncated responses - Implemented retry logic with smaller batch sizes - Enhanced error handling and fallback mechanisms - Successfully generates realistic survey data using LLM
1898 lines
78 KiB
Plaintext
1898 lines
78 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a8dbb4e8",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 🧪 Survey Synthetic Dataset Generator — Week 3 Task"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "8d86f629",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Base libraries ready. Pandera available: True\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"import os, re, json, time, uuid, math, random\n",
|
|
"from datetime import datetime, timedelta\n",
|
|
"from typing import List, Dict, Any\n",
|
|
"import numpy as np, pandas as pd\n",
|
|
"import pandera.pandas as pa\n",
|
|
"random.seed(7); np.random.seed(7)\n",
|
|
"print(\"✅ Base libraries ready. Pandera available:\", pa is not None)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "f196ae73",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def extract_strict_json(text: str):\n",
|
|
" \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n",
|
|
" if text is None:\n",
|
|
" raise ValueError(\"Empty model output.\")\n",
|
|
" \n",
|
|
" t = text.strip()\n",
|
|
" \n",
|
|
" # Strategy 1: Direct JSON parsing\n",
|
|
" try:\n",
|
|
" obj = json.loads(t)\n",
|
|
" if isinstance(obj, list):\n",
|
|
" return obj\n",
|
|
" elif isinstance(obj, dict):\n",
|
|
" for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n",
|
|
" if key in obj and isinstance(obj[key], list):\n",
|
|
" return obj[key]\n",
|
|
" if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n",
|
|
" return [obj[k] for k in sorted(obj.keys(), key=int)]\n",
|
|
" except json.JSONDecodeError:\n",
|
|
" pass\n",
|
|
" \n",
|
|
" # Strategy 2: Extract JSON from code blocks\n",
|
|
" if t.startswith(\"```\"):\n",
|
|
" t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n",
|
|
" \n",
|
|
" # Strategy 3: Find JSON array in text\n",
|
|
" start, end = t.find('['), t.rfind(']')\n",
|
|
" if start == -1 or end == -1 or end <= start:\n",
|
|
" raise ValueError(\"No JSON array found in model output.\")\n",
|
|
" \n",
|
|
" t = t[start:end+1]\n",
|
|
" \n",
|
|
" # Strategy 4: Fix common JSON issues\n",
|
|
" t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t) # Remove trailing commas\n",
|
|
" t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t) # Replace NaN/Infinity\n",
|
|
" t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\") # Remove invisible characters\n",
|
|
" \n",
|
|
" try:\n",
|
|
" return json.loads(t)\n",
|
|
" except json.JSONDecodeError as e:\n",
|
|
" raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3670fa0d",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1) Configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "d16bd03a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loaded config for 800 rows and 18 fields.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"CFG = {\n",
|
|
" \"rows\": 800,\n",
|
|
" \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n",
|
|
" \"fields\": [\n",
|
|
" {\"name\": \"response_id\", \"type\": \"uuid4\"},\n",
|
|
" {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n",
|
|
" {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n",
|
|
" {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n",
|
|
" {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n",
|
|
" {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n",
|
|
" {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n",
|
|
" {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n",
|
|
" {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n",
|
|
" {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n",
|
|
" {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n",
|
|
" {\"name\": \"attention_passed\", \"type\": \"bool\"},\n",
|
|
" {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
|
|
" {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
|
|
" {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
|
|
" {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
|
|
" {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n",
|
|
" {\"name\": \"is_detractor\", \"type\": \"bool\"}\n",
|
|
" ]\n",
|
|
"}\n",
|
|
"print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7da1f429",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2) Helpers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "d2f5fdff",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def sample_enum(values, probs=None, size=None):\n",
|
|
" values = list(values)\n",
|
|
" if probs is None:\n",
|
|
" probs = [1.0 / len(values)] * len(values)\n",
|
|
" return np.random.choice(values, p=probs, size=size)\n",
|
|
"\n",
|
|
"def sample_numeric(field_cfg, size=1):\n",
|
|
" t = field_cfg[\"type\"]\n",
|
|
" if t == \"int\":\n",
|
|
" lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n",
|
|
" dist = field_cfg.get(\"distribution\", \"uniform\")\n",
|
|
" if dist == \"uniform\":\n",
|
|
" return np.random.randint(lo, hi + 1, size=size)\n",
|
|
" elif dist == \"normal\":\n",
|
|
" mu = (lo + hi) / 2.0\n",
|
|
" sigma = (hi - lo) / 6.0\n",
|
|
" out = np.random.normal(mu, sigma, size=size)\n",
|
|
" return np.clip(out, lo, hi).astype(int)\n",
|
|
" else:\n",
|
|
" return np.random.randint(lo, hi + 1, size=size)\n",
|
|
" elif t == \"float\":\n",
|
|
" lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n",
|
|
" dist = field_cfg.get(\"distribution\", \"uniform\")\n",
|
|
" if dist == \"uniform\":\n",
|
|
" return np.random.uniform(lo, hi, size=size)\n",
|
|
" elif dist == \"normal\":\n",
|
|
" mu = (lo + hi) / 2.0\n",
|
|
" sigma = (hi - lo) / 6.0\n",
|
|
" return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n",
|
|
" elif dist == \"lognormal\":\n",
|
|
" mu = math.log(max(1e-3, (lo + hi) / 2.0))\n",
|
|
" sigma = 0.75\n",
|
|
" out = np.random.lognormal(mu, sigma, size=size)\n",
|
|
" return np.clip(out, lo, hi)\n",
|
|
" else:\n",
|
|
" return np.random.uniform(lo, hi, size=size)\n",
|
|
" else:\n",
|
|
" raise ValueError(\"Unsupported numeric type\")\n",
|
|
"\n",
|
|
"def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n",
|
|
" s = datetime.fromisoformat(start)\n",
|
|
" e = datetime.fromisoformat(end)\n",
|
|
" total = int((e - s).total_seconds())\n",
|
|
" r = np.random.randint(0, total, size=size)\n",
|
|
" return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5f24111a",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 3) Rule-based Generator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "cd61330d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>response_id</th>\n",
|
|
" <th>respondent_id</th>\n",
|
|
" <th>submitted_at</th>\n",
|
|
" <th>country</th>\n",
|
|
" <th>language</th>\n",
|
|
" <th>device</th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>education</th>\n",
|
|
" <th>income_band</th>\n",
|
|
" <th>completion_seconds</th>\n",
|
|
" <th>attention_passed</th>\n",
|
|
" <th>q_quality</th>\n",
|
|
" <th>q_value</th>\n",
|
|
" <th>q_ease</th>\n",
|
|
" <th>q_support</th>\n",
|
|
" <th>nps</th>\n",
|
|
" <th>is_detractor</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>f099c1b6-a4ae-4fb0-ba98-89a81008c424</td>\n",
|
|
" <td>71615</td>\n",
|
|
" <td>2024-04-13 19:02:44</td>\n",
|
|
" <td>ZA</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>web</td>\n",
|
|
" <td>47</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>897.995012</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>True</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b</td>\n",
|
|
" <td>68564</td>\n",
|
|
" <td>2024-03-05 23:30:30</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>android</td>\n",
|
|
" <td>67</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>lower_mid</td>\n",
|
|
" <td>935.607966</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>a9345f69-be75-46b9-8cd3-a276ce0a66bd</td>\n",
|
|
" <td>59689</td>\n",
|
|
" <td>2024-11-10 03:38:07</td>\n",
|
|
" <td>RW</td>\n",
|
|
" <td>sw</td>\n",
|
|
" <td>android</td>\n",
|
|
" <td>23</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>1431.517701</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>b4fa8625-d153-4465-ad73-1c4a48eed2f1</td>\n",
|
|
" <td>20742</td>\n",
|
|
" <td>2024-11-19 17:40:58</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>68</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>upper_mid</td>\n",
|
|
" <td>448.519416</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>10</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>e0ad4bbc-b576-4913-8786-302f06b5e9f7</td>\n",
|
|
" <td>63459</td>\n",
|
|
" <td>2024-07-28 04:23:37</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>34</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>1179.970734</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
|
|
"1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
|
|
"2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
|
|
"3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
|
|
"4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 ZA en web 47 male secondary low \n",
|
|
"1 KE en android 67 female bachelor lower_mid \n",
|
|
"2 RW sw android 23 male bachelor low \n",
|
|
"3 KE en ios 68 female secondary upper_mid \n",
|
|
"4 KE en ios 34 male secondary low \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 897.995012 True 5 3 1 \n",
|
|
"1 935.607966 True 1 5 2 \n",
|
|
"2 1431.517701 True 5 2 5 \n",
|
|
"3 448.519416 True 5 5 5 \n",
|
|
"4 1179.970734 True 3 1 3 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 3 4 True \n",
|
|
"1 3 5 False \n",
|
|
"2 5 7 False \n",
|
|
"3 3 10 False \n",
|
|
"4 3 5 False "
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n",
|
|
" n = CFG[\"rows\"]\n",
|
|
" dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n",
|
|
" data = {}\n",
|
|
" for f in CFG[\"fields\"]:\n",
|
|
" name, t = f[\"name\"], f[\"type\"]\n",
|
|
" if t == \"uuid4\":\n",
|
|
" data[name] = [str(uuid.uuid4()) for _ in range(n)]\n",
|
|
" elif t in (\"int\",\"float\"):\n",
|
|
" data[name] = sample_numeric(f, size=n)\n",
|
|
" elif t == \"enum\":\n",
|
|
" data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n",
|
|
" elif t == \"datetime\":\n",
|
|
" data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n",
|
|
" elif t == \"bool\":\n",
|
|
" data[name] = np.random.rand(n) < 0.9 # 90% True\n",
|
|
" else:\n",
|
|
" data[name] = [None]*n\n",
|
|
" df = pd.DataFrame(data)\n",
|
|
"\n",
|
|
" # Derive NPS roughly from likert questions\n",
|
|
" if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n",
|
|
" likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n",
|
|
" df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)\n",
|
|
"\n",
|
|
" # Heuristic target: is_detractor more likely when completion high & attention failed\n",
|
|
" if \"is_detractor\" in df.columns:\n",
|
|
" base = 0.25\n",
|
|
" comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n",
|
|
" attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n",
|
|
" boost = (comp > 900).astype(int) + (~attn).astype(int)\n",
|
|
" p = np.clip(base + 0.15*boost, 0.01, 0.95)\n",
|
|
" df[\"is_detractor\"] = np.random.rand(n) < p\n",
|
|
"\n",
|
|
" return df\n",
|
|
"\n",
|
|
"df_rule = generate_rule_based(CFG)\n",
|
|
"df_rule.head()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "dd9eff20",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4) Validation (Pandera optional)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "9a4ef86a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Validation error: {\n",
|
|
" \"SCHEMA\": {\n",
|
|
" \"WRONG_DATATYPE\": [\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"respondent_id\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'respondent_id' to have type int64, got int32\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"age\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'age' to have type int64, got int32\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"q_quality\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'q_quality' to have type int64, got int32\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"q_value\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'q_value' to have type int64, got int32\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"q_ease\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'q_ease' to have type int64, got int32\"\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"schema\": null,\n",
|
|
" \"column\": \"q_support\",\n",
|
|
" \"check\": \"dtype('int64')\",\n",
|
|
" \"error\": \"expected series 'q_support' to have type int64, got int32\"\n",
|
|
" }\n",
|
|
" ]\n",
|
|
" }\n",
|
|
"}\n",
|
|
"{'engine': 'pandera', 'valid_rows': 800, 'invalid_rows': 0, 'notes': 'Non-strict mode.'}\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>response_id</th>\n",
|
|
" <th>respondent_id</th>\n",
|
|
" <th>submitted_at</th>\n",
|
|
" <th>country</th>\n",
|
|
" <th>language</th>\n",
|
|
" <th>device</th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>education</th>\n",
|
|
" <th>income_band</th>\n",
|
|
" <th>completion_seconds</th>\n",
|
|
" <th>attention_passed</th>\n",
|
|
" <th>q_quality</th>\n",
|
|
" <th>q_value</th>\n",
|
|
" <th>q_ease</th>\n",
|
|
" <th>q_support</th>\n",
|
|
" <th>nps</th>\n",
|
|
" <th>is_detractor</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>f099c1b6-a4ae-4fb0-ba98-89a81008c424</td>\n",
|
|
" <td>71615</td>\n",
|
|
" <td>2024-04-13 19:02:44</td>\n",
|
|
" <td>ZA</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>web</td>\n",
|
|
" <td>47</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>897.995012</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>True</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b</td>\n",
|
|
" <td>68564</td>\n",
|
|
" <td>2024-03-05 23:30:30</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>android</td>\n",
|
|
" <td>67</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>lower_mid</td>\n",
|
|
" <td>935.607966</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>a9345f69-be75-46b9-8cd3-a276ce0a66bd</td>\n",
|
|
" <td>59689</td>\n",
|
|
" <td>2024-11-10 03:38:07</td>\n",
|
|
" <td>RW</td>\n",
|
|
" <td>sw</td>\n",
|
|
" <td>android</td>\n",
|
|
" <td>23</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>1431.517701</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>7</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>b4fa8625-d153-4465-ad73-1c4a48eed2f1</td>\n",
|
|
" <td>20742</td>\n",
|
|
" <td>2024-11-19 17:40:58</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>68</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>upper_mid</td>\n",
|
|
" <td>448.519416</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>10</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>e0ad4bbc-b576-4913-8786-302f06b5e9f7</td>\n",
|
|
" <td>63459</td>\n",
|
|
" <td>2024-07-28 04:23:37</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>34</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>1179.970734</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
|
|
"1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
|
|
"2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
|
|
"3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
|
|
"4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 ZA en web 47 male secondary low \n",
|
|
"1 KE en android 67 female bachelor lower_mid \n",
|
|
"2 RW sw android 23 male bachelor low \n",
|
|
"3 KE en ios 68 female secondary upper_mid \n",
|
|
"4 KE en ios 34 male secondary low \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 897.995012 True 5 3 1 \n",
|
|
"1 935.607966 True 1 5 2 \n",
|
|
"2 1431.517701 True 5 2 5 \n",
|
|
"3 448.519416 True 5 5 5 \n",
|
|
"4 1179.970734 True 3 1 3 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 3 4 True \n",
|
|
"1 3 5 False \n",
|
|
"2 5 7 False \n",
|
|
"3 3 10 False \n",
|
|
"4 3 5 False "
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"def build_pandera_schema(CFG):\n",
|
|
" if pa is None:\n",
|
|
" return None\n",
|
|
" cols = {}\n",
|
|
" for f in CFG[\"fields\"]:\n",
|
|
" t, name = f[\"type\"], f[\"name\"]\n",
|
|
" if t == \"int\": cols[name] = pa.Column(int)\n",
|
|
" elif t == \"float\": cols[name] = pa.Column(float)\n",
|
|
" elif t == \"enum\": cols[name] = pa.Column(object)\n",
|
|
" elif t == \"datetime\": cols[name] = pa.Column(object)\n",
|
|
" elif t == \"uuid4\": cols[name] = pa.Column(object)\n",
|
|
" elif t == \"bool\": cols[name] = pa.Column(bool)\n",
|
|
" else: cols[name] = pa.Column(object)\n",
|
|
" return pa.DataFrameSchema(cols) if pa is not None else None\n",
|
|
"\n",
|
|
"def validate_df(df, CFG):\n",
|
|
" schema = build_pandera_schema(CFG)\n",
|
|
" if schema is None:\n",
|
|
" return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n",
|
|
" try:\n",
|
|
" v = schema.validate(df, lazy=True)\n",
|
|
" return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n",
|
|
" except Exception as e:\n",
|
|
" print(\"Validation error:\", e)\n",
|
|
" return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n",
|
|
"\n",
|
|
"validated_rule, report_rule = validate_df(df_rule, CFG)\n",
|
|
"print(report_rule)\n",
|
|
"validated_rule.head()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d5f1d93a",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 5) Save"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "73626b4c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Saved: data/survey_rule_20251023T004106Z.csv\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1233117399.py:3: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
|
|
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"from pathlib import Path\n",
|
|
"out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
|
|
"ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
|
|
"csv_path = out / f\"survey_rule_{ts}.csv\"\n",
|
|
"validated_rule.to_csv(csv_path, index=False)\n",
|
|
"print(\"Saved:\", csv_path.as_posix())\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "87c89b51",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"id": "24e94771",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Fixed LLM Generation Functions\n",
|
|
"def create_survey_prompt(CFG, n_rows=50):\n",
|
|
" \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n",
|
|
" fields_desc = []\n",
|
|
" for field in CFG['fields']:\n",
|
|
" name = field['name']\n",
|
|
" field_type = field['type']\n",
|
|
" \n",
|
|
" if field_type == 'int':\n",
|
|
" min_val = field.get('min', 0)\n",
|
|
" max_val = field.get('max', 100)\n",
|
|
" fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n",
|
|
" elif field_type == 'float':\n",
|
|
" min_val = field.get('min', 0.0)\n",
|
|
" max_val = field.get('max', 100.0)\n",
|
|
" fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n",
|
|
" elif field_type == 'enum':\n",
|
|
" values = field.get('values', [])\n",
|
|
" fields_desc.append(f\" - {name}: one of {values}\")\n",
|
|
" elif field_type == 'bool':\n",
|
|
" fields_desc.append(f\" - {name}: boolean (true/false)\")\n",
|
|
" elif field_type == 'uuid4':\n",
|
|
" fields_desc.append(f\" - {name}: UUID string\")\n",
|
|
" elif field_type == 'datetime':\n",
|
|
" fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n",
|
|
" fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n",
|
|
" else:\n",
|
|
" fields_desc.append(f\" - {name}: {field_type}\")\n",
|
|
" \n",
|
|
" prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n",
|
|
"\n",
|
|
"Schema:\n",
|
|
"{chr(10).join(fields_desc)}\n",
|
|
"\n",
|
|
"CRITICAL REQUIREMENTS:\n",
|
|
"- Return a JSON object with a \"responses\" key containing an array\n",
|
|
"- Each object in the array must have all required fields\n",
|
|
"- Use realistic, diverse values for survey responses\n",
|
|
"- No trailing commas\n",
|
|
"- No comments or explanations\n",
|
|
"\n",
|
|
"Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n",
|
|
"\n",
|
|
"Example structure:\n",
|
|
"{{\n",
|
|
" \"responses\": [\n",
|
|
" {{\n",
|
|
" \"response_id\": \"uuid-string\",\n",
|
|
" \"respondent_id\": 12345,\n",
|
|
" \"submitted_at\": \"2024-01-01 12:00:00\",\n",
|
|
" \"country\": \"KE\",\n",
|
|
" \"language\": \"en\",\n",
|
|
" \"device\": \"android\",\n",
|
|
" \"age\": 25,\n",
|
|
" \"gender\": \"female\",\n",
|
|
" \"education\": \"bachelor\",\n",
|
|
" \"income_band\": \"upper_mid\",\n",
|
|
" \"completion_seconds\": 300.5,\n",
|
|
" \"attention_passed\": true,\n",
|
|
" \"q_quality\": 4,\n",
|
|
" \"q_value\": 3,\n",
|
|
" \"q_ease\": 5,\n",
|
|
" \"q_support\": 4,\n",
|
|
" \"nps\": 8,\n",
|
|
" \"is_detractor\": false\n",
|
|
" }},\n",
|
|
" ...\n",
|
|
" ]\n",
|
|
"}}\n",
|
|
"\n",
|
|
"IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n",
|
|
" \n",
|
|
" return prompt\n",
|
|
"\n",
|
|
"def repair_truncated_json(content):\n",
|
|
" \"\"\"Attempt to repair truncated JSON responses\"\"\"\n",
|
|
" content = content.strip()\n",
|
|
" \n",
|
|
" # If it starts with { but doesn't end with }, try to close it\n",
|
|
" if content.startswith('{') and not content.endswith('}'):\n",
|
|
" # Find the last complete object in the responses array\n",
|
|
" responses_start = content.find('\"responses\": [')\n",
|
|
" if responses_start != -1:\n",
|
|
" # Find the last complete object\n",
|
|
" brace_count = 0\n",
|
|
" last_complete_pos = -1\n",
|
|
" in_string = False\n",
|
|
" escape_next = False\n",
|
|
" \n",
|
|
" for i, char in enumerate(content[responses_start:], responses_start):\n",
|
|
" if escape_next:\n",
|
|
" escape_next = False\n",
|
|
" continue\n",
|
|
" \n",
|
|
" if char == '\\\\':\n",
|
|
" escape_next = True\n",
|
|
" continue\n",
|
|
" \n",
|
|
" if char == '\"' and not escape_next:\n",
|
|
" in_string = not in_string\n",
|
|
" continue\n",
|
|
" \n",
|
|
" if not in_string:\n",
|
|
" if char == '{':\n",
|
|
" brace_count += 1\n",
|
|
" elif char == '}':\n",
|
|
" brace_count -= 1\n",
|
|
" if brace_count == 0:\n",
|
|
" last_complete_pos = i\n",
|
|
" break\n",
|
|
" \n",
|
|
" if last_complete_pos != -1:\n",
|
|
" # Truncate at the last complete object and close the JSON\n",
|
|
" repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n",
|
|
" print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n",
|
|
" return repaired\n",
|
|
" \n",
|
|
" return content\n",
|
|
"\n",
|
|
"def fixed_llm_generate_batch(CFG, n_rows=50):\n",
|
|
" \"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n",
|
|
" if not os.getenv('OPENAI_API_KEY'):\n",
|
|
" print(\"No OpenAI API key, using rule-based fallback\")\n",
|
|
" tmp = dict(CFG); tmp['rows'] = n_rows\n",
|
|
" return generate_rule_based(tmp)\n",
|
|
" \n",
|
|
" try:\n",
|
|
" from openai import OpenAI\n",
|
|
" client = OpenAI()\n",
|
|
" \n",
|
|
" prompt = create_survey_prompt(CFG, n_rows)\n",
|
|
" \n",
|
|
" print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n",
|
|
" \n",
|
|
" # Calculate appropriate max_tokens based on batch size\n",
|
|
" # Roughly 200-300 tokens per row, with some buffer\n",
|
|
" estimated_tokens = n_rows * 300 + 500 # Buffer for JSON structure\n",
|
|
" max_tokens = min(max(estimated_tokens, 2000), 8000) # Between 2k-8k tokens\n",
|
|
" \n",
|
|
" print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n",
|
|
" \n",
|
|
" response = client.chat.completions.create(\n",
|
|
" model='gpt-4o-mini',\n",
|
|
" messages=[\n",
|
|
" {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n",
|
|
" {'role': 'user', 'content': prompt}\n",
|
|
" ],\n",
|
|
" temperature=0.3,\n",
|
|
" max_tokens=max_tokens,\n",
|
|
" response_format={'type': 'json_object'}\n",
|
|
" )\n",
|
|
" \n",
|
|
" content = response.choices[0].message.content\n",
|
|
" print(f\"📝 Raw response length: {len(content)} characters\")\n",
|
|
" \n",
|
|
" # Check if response appears truncated\n",
|
|
" if not content.strip().endswith('}') and not content.strip().endswith(']'):\n",
|
|
" print(\"⚠️ Response appears truncated, attempting repair...\")\n",
|
|
" content = repair_truncated_json(content)\n",
|
|
" \n",
|
|
" # Try to extract JSON with improved logic\n",
|
|
" try:\n",
|
|
" data = json.loads(content)\n",
|
|
" print(f\"🔍 Parsed JSON type: {type(data)}\")\n",
|
|
" \n",
|
|
" if isinstance(data, list):\n",
|
|
" df = pd.DataFrame(data)\n",
|
|
" print(f\"📊 Direct array: {len(df)} rows\")\n",
|
|
" elif isinstance(data, dict):\n",
|
|
" # Check for common keys that might contain the data\n",
|
|
" for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n",
|
|
" if key in data and isinstance(data[key], list):\n",
|
|
" df = pd.DataFrame(data[key])\n",
|
|
" print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
|
|
" break\n",
|
|
" else:\n",
|
|
" # If no standard key found, check if all values are lists/objects\n",
|
|
" list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n",
|
|
" if list_keys:\n",
|
|
" # Use the first list key found\n",
|
|
" key = list_keys[0]\n",
|
|
" df = pd.DataFrame(data[key])\n",
|
|
" print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
|
|
" else:\n",
|
|
" # Try to convert the dict values to a list\n",
|
|
" if all(isinstance(v, dict) for v in data.values()):\n",
|
|
" df = pd.DataFrame(list(data.values()))\n",
|
|
" print(f\"📊 Converted dict values: {len(df)} rows\")\n",
|
|
" else:\n",
|
|
" raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n",
|
|
" else:\n",
|
|
" raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n",
|
|
" \n",
|
|
" if len(df) == n_rows:\n",
|
|
" print(f\"✅ Successfully generated {len(df)} survey responses\")\n",
|
|
" return df\n",
|
|
" else:\n",
|
|
" print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}\")\n",
|
|
" if len(df) > 0:\n",
|
|
" return df\n",
|
|
" else:\n",
|
|
" raise ValueError(\"No data generated\")\n",
|
|
" \n",
|
|
" except json.JSONDecodeError as e:\n",
|
|
" print(f\"❌ JSON parsing failed: {str(e)}\")\n",
|
|
" # Try the improved extract_strict_json function\n",
|
|
" try:\n",
|
|
" data = extract_strict_json(content)\n",
|
|
" df = pd.DataFrame(data)\n",
|
|
" print(f\"✅ Recovered with strict parsing: {len(df)} rows\")\n",
|
|
" return df\n",
|
|
" except Exception as e2:\n",
|
|
" print(f\"❌ Strict parsing also failed: {str(e2)}\")\n",
|
|
" # Print a sample of the content for debugging\n",
|
|
" print(f\"🔍 Content sample: {content[:500]}...\")\n",
|
|
" raise e2\n",
|
|
" \n",
|
|
" except Exception as e:\n",
|
|
" print(f'❌ LLM error, fallback to rule-based mock: {str(e)}')\n",
|
|
" tmp = dict(CFG); tmp['rows'] = n_rows\n",
|
|
" return generate_rule_based(tmp)\n",
|
|
"\n",
|
|
"def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n",
|
|
" \"\"\"Fixed LLM generation with adaptive batch processing\"\"\"\n",
|
|
" print(f\"🚀 Generating {total_rows} survey responses with adaptive batching\")\n",
|
|
" \n",
|
|
" # Adaptive batch sizing based on total rows\n",
|
|
" if total_rows <= 20:\n",
|
|
" optimal_batch_size = min(batch_size, total_rows)\n",
|
|
" elif total_rows <= 50:\n",
|
|
" optimal_batch_size = min(15, batch_size)\n",
|
|
" elif total_rows <= 100:\n",
|
|
" optimal_batch_size = min(10, batch_size)\n",
|
|
" else:\n",
|
|
" optimal_batch_size = min(8, batch_size)\n",
|
|
" \n",
|
|
" print(f\"📊 Using optimal batch size: {optimal_batch_size}\")\n",
|
|
" \n",
|
|
" all_dataframes = []\n",
|
|
" remaining = total_rows\n",
|
|
" \n",
|
|
" while remaining > 0:\n",
|
|
" current_batch_size = min(optimal_batch_size, remaining)\n",
|
|
" print(f\"\\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n",
|
|
" \n",
|
|
" try:\n",
|
|
" batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n",
|
|
" all_dataframes.append(batch_df)\n",
|
|
" remaining -= len(batch_df)\n",
|
|
" \n",
|
|
" # Small delay between batches to avoid rate limits\n",
|
|
" if remaining > 0:\n",
|
|
" time.sleep(1.5)\n",
|
|
" \n",
|
|
" except Exception as e:\n",
|
|
" print(f\"❌ Batch failed: {str(e)}\")\n",
|
|
" print(f\"🔄 Retrying with smaller batch size...\")\n",
|
|
" \n",
|
|
" # Try with smaller batch size\n",
|
|
" smaller_batch = max(1, current_batch_size // 2)\n",
|
|
" if smaller_batch < current_batch_size:\n",
|
|
" try:\n",
|
|
" print(f\"🔄 Retrying with {smaller_batch} rows...\")\n",
|
|
" batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n",
|
|
" all_dataframes.append(batch_df)\n",
|
|
" remaining -= len(batch_df)\n",
|
|
" continue\n",
|
|
" except Exception as e2:\n",
|
|
" print(f\"❌ Retry also failed: {str(e2)}\")\n",
|
|
" \n",
|
|
" print(f\"Using rule-based fallback for remaining {remaining} rows\")\n",
|
|
" fallback_df = generate_rule_based(CFG, remaining)\n",
|
|
" all_dataframes.append(fallback_df)\n",
|
|
" break\n",
|
|
" \n",
|
|
" if all_dataframes:\n",
|
|
" result = pd.concat(all_dataframes, ignore_index=True)\n",
|
|
" print(f\"✅ Generated total: {len(result)} survey responses\")\n",
|
|
" return result\n",
|
|
" else:\n",
|
|
" print(\"❌ No data generated\")\n",
|
|
" return pd.DataFrame()\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e1af410e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🧪 Testing LLM generation...\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5236 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📊 Generated dataset shape: (10, 18)\n",
|
|
"\n",
|
|
"📋 First few rows:\n",
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 f3e9b9d1-4e9e-4f8a-9b5c-7e3cbb1c4e5e 10234 2023-10-01 14:23:45 \n",
|
|
"1 a1c5f6d3-1f5b-4e8a-8c7a-5e2c3f4b8e1b 20456 2023-10-01 15:10:12 \n",
|
|
"2 c2b3e4f5-5d6e-4b8a-9f3c-8e1a2f9b4e3c 30567 2023-10-01 16:45:30 \n",
|
|
"3 d4e5f6b7-6e8f-4b9a-8c7d-9e2f3c4b5e6f 40678 2023-10-01 17:30:00 \n",
|
|
"4 e5f6a7b8-7f9a-4c0a-9e2f-1e3c4b5e6f7a 50789 2023-10-01 18:15:15 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 KE en android 29 female bachelor upper_mid \n",
|
|
"1 UG sw web 34 male secondary lower_mid \n",
|
|
"2 TZ en ios 42 nonbinary diploma high \n",
|
|
"3 RW sw android 27 female bachelor upper_mid \n",
|
|
"4 NG en web 36 male postgraduate high \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 450.0 True 4 5 4 \n",
|
|
"1 600.5 True 3 4 3 \n",
|
|
"2 720.0 True 5 5 5 \n",
|
|
"3 390.0 True 4 4 4 \n",
|
|
"4 800.0 True 5 5 5 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 5 9 False \n",
|
|
"1 4 7 False \n",
|
|
"2 5 10 False \n",
|
|
"3 4 8 False \n",
|
|
"4 5 9 False \n",
|
|
"\n",
|
|
"📈 Data types:\n",
|
|
"response_id object\n",
|
|
"respondent_id int64\n",
|
|
"submitted_at object\n",
|
|
"country object\n",
|
|
"language object\n",
|
|
"device object\n",
|
|
"age int64\n",
|
|
"gender object\n",
|
|
"education object\n",
|
|
"income_band object\n",
|
|
"completion_seconds float64\n",
|
|
"attention_passed bool\n",
|
|
"q_quality int64\n",
|
|
"q_value int64\n",
|
|
"q_ease int64\n",
|
|
"q_support int64\n",
|
|
"nps int64\n",
|
|
"is_detractor bool\n",
|
|
"dtype: object\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test the fixed LLM generation\n",
|
|
"print(\"🧪 Testing LLM generation...\")\n",
|
|
"\n",
|
|
"# Test with small dataset first\n",
|
|
"test_df = fixed_llm_generate_batch(CFG, 10)\n",
|
|
"print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
|
|
"print(f\"\\n📋 First few rows:\")\n",
|
|
"print(test_df.head())\n",
|
|
"print(f\"\\n📈 Data types:\")\n",
|
|
"print(test_df.dtypes)\n",
|
|
"\n",
|
|
"# Debug function to see what the LLM is actually returning\n",
|
|
"def debug_llm_response(CFG, n_rows=5):\n",
|
|
" \"\"\"Debug function to see raw LLM response\"\"\"\n",
|
|
" if not os.getenv('OPENAI_API_KEY'):\n",
|
|
" print(\"No OpenAI API key available for debugging\")\n",
|
|
" return\n",
|
|
" \n",
|
|
" try:\n",
|
|
" from openai import OpenAI\n",
|
|
" client = OpenAI()\n",
|
|
" \n",
|
|
" prompt = create_survey_prompt(CFG, n_rows)\n",
|
|
" \n",
|
|
" print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n",
|
|
" print(f\"📝 Prompt length: {len(prompt)} characters\")\n",
|
|
" \n",
|
|
" response = client.chat.completions.create(\n",
|
|
" model='gpt-4o-mini',\n",
|
|
" messages=[\n",
|
|
" {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n",
|
|
" {'role': 'user', 'content': prompt}\n",
|
|
" ],\n",
|
|
" temperature=0.3,\n",
|
|
" max_tokens=2000,\n",
|
|
" response_format={'type': 'json_object'}\n",
|
|
" )\n",
|
|
" \n",
|
|
" content = response.choices[0].message.content\n",
|
|
" print(f\"📝 Raw response length: {len(content)} characters\")\n",
|
|
" print(f\"🔍 First 200 characters: {content[:200]}\")\n",
|
|
" print(f\"🔍 Last 200 characters: {content[-200:]}\")\n",
|
|
" \n",
|
|
" # Try to parse\n",
|
|
" try:\n",
|
|
" data = json.loads(content)\n",
|
|
" print(f\"✅ JSON parsed successfully\")\n",
|
|
" print(f\"🔍 Data type: {type(data)}\")\n",
|
|
" if isinstance(data, dict):\n",
|
|
" print(f\"🔍 Dict keys: {list(data.keys())}\")\n",
|
|
" elif isinstance(data, list):\n",
|
|
" print(f\"🔍 List length: {len(data)}\")\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"❌ JSON parsing failed: {str(e)}\")\n",
|
|
" \n",
|
|
" except Exception as e:\n",
|
|
" print(f\"❌ Debug failed: {str(e)}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"id": "75c90739",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🧪 Testing the fixed LLM generation...\n",
|
|
"🔄 Generating 5 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 2000 (estimated: 2000)\n",
|
|
"📝 Raw response length: 2629 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 5 rows\n",
|
|
"✅ Successfully generated 5 survey responses\n",
|
|
"\n",
|
|
"📊 Generated dataset shape: (5, 18)\n",
|
|
"\n",
|
|
"📋 First few rows:\n",
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 d8b1c6f3-6f7a-4b4f-9c5f-3a5f8b6e2f1e 12345 2023-10-01 14:30:00 \n",
|
|
"1 f3a8e3c1-9b4e-4e5e-9c2b-8f5e3c9b1f3d 67890 2023-10-01 15:00:00 \n",
|
|
"2 c9c8e3f1-2b4f-4a6c-8c2e-2a5f3c8e1f2b 54321 2023-10-01 16:15:00 \n",
|
|
"3 a5b3c6d2-1e4f-4c5e-9a1f-1f6a7b8e3c9f 98765 2023-10-01 17:45:00 \n",
|
|
"4 b8f4c3e2-2e4f-4c5e-8a2f-4c5e3b8e2f1a 13579 2023-10-01 18:30:00 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 KE en android 29 female bachelor upper_mid \n",
|
|
"1 UG sw web 34 male diploma lower_mid \n",
|
|
"2 TZ en ios 42 nonbinary postgraduate high \n",
|
|
"3 RW sw android 27 female secondary low \n",
|
|
"4 NG en web 55 male bachelor upper_mid \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 420.0 True 5 4 4 \n",
|
|
"1 600.0 True 3 3 2 \n",
|
|
"2 300.5 True 4 5 4 \n",
|
|
"3 720.0 False 2 3 3 \n",
|
|
"4 540.0 True 5 5 5 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 5 9 False \n",
|
|
"1 4 5 False \n",
|
|
"2 5 10 False \n",
|
|
"3 2 3 True \n",
|
|
"4 5 8 False \n",
|
|
"\n",
|
|
"📈 Data types:\n",
|
|
"response_id object\n",
|
|
"respondent_id int64\n",
|
|
"submitted_at object\n",
|
|
"country object\n",
|
|
"language object\n",
|
|
"device object\n",
|
|
"age int64\n",
|
|
"gender object\n",
|
|
"education object\n",
|
|
"income_band object\n",
|
|
"completion_seconds float64\n",
|
|
"attention_passed bool\n",
|
|
"q_quality int64\n",
|
|
"q_value int64\n",
|
|
"q_ease int64\n",
|
|
"q_support int64\n",
|
|
"nps int64\n",
|
|
"is_detractor bool\n",
|
|
"dtype: object\n",
|
|
"\n",
|
|
"✅ SUCCESS! LLM generation is now working!\n",
|
|
"📊 Generated 5 survey responses using LLM\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test the fixed implementation\n",
|
|
"print(\"🧪 Testing the fixed LLM generation...\")\n",
|
|
"\n",
|
|
"# Test with small dataset\n",
|
|
"test_df = fixed_llm_generate_batch(CFG, 5)\n",
|
|
"print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
|
|
"print(f\"\\n📋 First few rows:\")\n",
|
|
"print(test_df.head())\n",
|
|
"print(f\"\\n📈 Data types:\")\n",
|
|
"print(test_df.dtypes)\n",
|
|
"\n",
|
|
"if not test_df.empty:\n",
|
|
" print(f\"\\n✅ SUCCESS! LLM generation is now working!\")\n",
|
|
" print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n",
|
|
"else:\n",
|
|
" print(f\"\\n❌ Still having issues with LLM generation\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"id": "dd83b842",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🚀 Testing larger dataset generation...\n",
|
|
"🚀 Generating 100 survey responses with adaptive batching\n",
|
|
"📊 Using optimal batch size: 10\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 100)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5238 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 90)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5235 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 80)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5232 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 70)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5239 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 60)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5238 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 50)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5236 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 40)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5229 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 30)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5244 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 20)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5234 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 10 rows (remaining: 10)\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5238 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"✅ Generated total: 100 survey responses\n",
|
|
"\n",
|
|
"📊 Large dataset shape: (100, 18)\n",
|
|
"\n",
|
|
"📈 Summary statistics:\n",
|
|
" respondent_id age completion_seconds q_quality q_value \\\n",
|
|
"count 100.000000 100.000000 100.000000 100.000000 100.000000 \n",
|
|
"mean 33513.700000 34.070000 588.525000 3.740000 3.910000 \n",
|
|
"std 29233.800863 7.835757 230.530212 1.001211 0.995901 \n",
|
|
"min 10001.000000 22.000000 120.500000 2.000000 2.000000 \n",
|
|
"25% 10009.000000 28.000000 420.375000 3.000000 3.000000 \n",
|
|
"50% 15122.500000 33.000000 600.000000 4.000000 4.000000 \n",
|
|
"75% 55955.750000 39.250000 720.000000 5.000000 5.000000 \n",
|
|
"max 98765.000000 50.000000 1500.000000 5.000000 5.000000 \n",
|
|
"\n",
|
|
" q_ease q_support nps \n",
|
|
"count 100.000000 100.000000 100.000000 \n",
|
|
"mean 3.900000 3.910000 6.990000 \n",
|
|
"std 0.937437 0.985706 2.333312 \n",
|
|
"min 2.000000 2.000000 2.000000 \n",
|
|
"25% 3.000000 3.000000 5.000000 \n",
|
|
"50% 4.000000 4.000000 7.000000 \n",
|
|
"75% 5.000000 5.000000 9.000000 \n",
|
|
"max 5.000000 5.000000 10.000000 \n",
|
|
"💾 Saved: data\\survey_llm_fixed_20251023T005139Z.csv\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\2716383900.py:12: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
|
|
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"#Test larger dataset generation \n",
|
|
"print(\"🚀 Testing larger dataset generation...\")\n",
|
|
"large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n",
|
|
"if not large_df.empty:\n",
|
|
" print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n",
|
|
" print(f\"\\n📈 Summary statistics:\")\n",
|
|
" print(large_df.describe())\n",
|
|
" \n",
|
|
" # Save the results\n",
|
|
" from pathlib import Path\n",
|
|
" out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
|
|
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
|
|
" csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n",
|
|
" large_df.to_csv(csv_path, index=False)\n",
|
|
" print(f\"💾 Saved: {csv_path}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6029d3e2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"LLM available: True\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"def build_json_schema(CFG):\n",
|
|
" schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n",
|
|
" props = schema['items']['properties']; req = schema['items']['required']\n",
|
|
" for f in CFG['fields']:\n",
|
|
" name, t = f['name'], f['type']\n",
|
|
" req.append(name)\n",
|
|
" if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n",
|
|
" elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n",
|
|
" elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n",
|
|
" elif t == 'bool': props[name] = {'type':'boolean'}\n",
|
|
" else: props[name] = {'type':'string'}\n",
|
|
" return schema\n",
|
|
"\n",
|
|
"PROMPT_PREAMBLE = (\n",
|
|
" \"You are a data generator. Return ONLY JSON. \"\n",
|
|
" \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n",
|
|
" \"No prose, no code fences, no trailing commas.\"\n",
|
|
")\n",
|
|
"\n",
|
|
"def render_prompt(CFG, n_rows=100):\n",
|
|
" minimal_cfg = {'fields': []}\n",
|
|
" for f in CFG['fields']:\n",
|
|
" base = {k: f[k] for k in ['name','type'] if k in f}\n",
|
|
" if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n",
|
|
" if 'values' in f: base.update({'values': f['values']})\n",
|
|
" if 'fmt' in f: base.update({'fmt': f['fmt']})\n",
|
|
" minimal_cfg['fields'].append(base)\n",
|
|
" return {\n",
|
|
" 'preamble': PROMPT_PREAMBLE,\n",
|
|
" 'n_rows': n_rows,\n",
|
|
" 'schema': build_json_schema(CFG),\n",
|
|
" 'constraints': minimal_cfg,\n",
|
|
" 'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n",
|
|
" }\n",
|
|
"\n",
|
|
"def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n",
|
|
" try:\n",
|
|
" obj = json.loads(raw)\n",
|
|
" if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n",
|
|
" return pd.DataFrame(obj['rows'])\n",
|
|
" except Exception:\n",
|
|
" pass\n",
|
|
" data = extract_strict_json(raw)\n",
|
|
" return pd.DataFrame(data)\n",
|
|
"\n",
|
|
"USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n",
|
|
"print('LLM available:', USE_LLM)\n",
|
|
"\n",
|
|
"def llm_generate_batch(CFG, n_rows=50):\n",
|
|
" if USE_LLM:\n",
|
|
" try:\n",
|
|
" from openai import OpenAI\n",
|
|
" client = OpenAI()\n",
|
|
" prompt = json.dumps(render_prompt(CFG, n_rows))\n",
|
|
" resp = client.chat.completions.create(\n",
|
|
" model='gpt-4o-mini',\n",
|
|
" response_format={'type': 'json_object'},\n",
|
|
" messages=[\n",
|
|
" {'role':'system','content':'You output strict JSON only.'},\n",
|
|
" {'role':'user','content': prompt}\n",
|
|
" ],\n",
|
|
" temperature=0.2,\n",
|
|
" max_tokens=8192,\n",
|
|
" )\n",
|
|
" raw = resp.choices[0].message.content\n",
|
|
" try:\n",
|
|
" return parse_llm_json_to_df(raw)\n",
|
|
" except Exception:\n",
|
|
" stricter = (\n",
|
|
" prompt\n",
|
|
" + \"\\nReturn ONLY a JSON object structured as: \"\n",
|
|
" + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n",
|
|
" + \"No prose, no explanations.\"\n",
|
|
" )\n",
|
|
" resp2 = client.chat.completions.create(\n",
|
|
" model='gpt-4o-mini',\n",
|
|
" response_format={'type': 'json_object'},\n",
|
|
" messages=[\n",
|
|
" {'role':'system','content':'You output strict JSON only.'},\n",
|
|
" {'role':'user','content': stricter}\n",
|
|
" ],\n",
|
|
" temperature=0.2,\n",
|
|
" max_tokens=8192,\n",
|
|
" )\n",
|
|
" raw2 = resp2.choices[0].message.content\n",
|
|
" return parse_llm_json_to_df(raw2)\n",
|
|
" except Exception as e:\n",
|
|
" print('LLM error, fallback to rule-based mock:', e)\n",
|
|
" tmp = dict(CFG); tmp['rows'] = n_rows\n",
|
|
" return generate_rule_based(tmp)\n",
|
|
"\n",
|
|
"def generate_llm(CFG, total_rows=200, batch_size=50):\n",
|
|
" dfs = []; remaining = total_rows\n",
|
|
" while remaining > 0:\n",
|
|
" b = min(batch_size, remaining)\n",
|
|
" dfs.append(llm_generate_batch(CFG, n_rows=b))\n",
|
|
" remaining -= b\n",
|
|
" time.sleep(0.2)\n",
|
|
" return pd.concat(dfs, ignore_index=True)\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "2e759087",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"LLM error, fallback to rule-based mock: No JSON array found in model output.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>response_id</th>\n",
|
|
" <th>respondent_id</th>\n",
|
|
" <th>submitted_at</th>\n",
|
|
" <th>country</th>\n",
|
|
" <th>language</th>\n",
|
|
" <th>device</th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>education</th>\n",
|
|
" <th>income_band</th>\n",
|
|
" <th>completion_seconds</th>\n",
|
|
" <th>attention_passed</th>\n",
|
|
" <th>q_quality</th>\n",
|
|
" <th>q_value</th>\n",
|
|
" <th>q_ease</th>\n",
|
|
" <th>q_support</th>\n",
|
|
" <th>nps</th>\n",
|
|
" <th>is_detractor</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>9e7811bd-27ee-4b7c-9b7a-c98441e337f0</td>\n",
|
|
" <td>40160</td>\n",
|
|
" <td>2024-08-18 19:10:06</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>sw</td>\n",
|
|
" <td>web</td>\n",
|
|
" <td>28</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>lower_mid</td>\n",
|
|
" <td>1800.000000</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>True</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>85ec8b90-5468-4880-8309-e325da14d877</td>\n",
|
|
" <td>55381</td>\n",
|
|
" <td>2025-01-24 12:21:13</td>\n",
|
|
" <td>TZ</td>\n",
|
|
" <td>sw</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>23</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>high</td>\n",
|
|
" <td>431.412783</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>498dff10-040f-4206-8170-dfce0d5a69f0</td>\n",
|
|
" <td>48338</td>\n",
|
|
" <td>2025-07-15 22:21:54</td>\n",
|
|
" <td>TZ</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>ios</td>\n",
|
|
" <td>49</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>low</td>\n",
|
|
" <td>1800.000000</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>ddf11d94-5d6e-4322-9811-4e763f5ed46b</td>\n",
|
|
" <td>59925</td>\n",
|
|
" <td>2025-01-27 00:16:57</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>web</td>\n",
|
|
" <td>22</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>bachelor</td>\n",
|
|
" <td>upper_mid</td>\n",
|
|
" <td>656.050991</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>2ef22a0c-fd13-4798-9276-f43831b8f7bc</td>\n",
|
|
" <td>68993</td>\n",
|
|
" <td>2024-08-19 04:21:49</td>\n",
|
|
" <td>KE</td>\n",
|
|
" <td>en</td>\n",
|
|
" <td>android</td>\n",
|
|
" <td>40</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>secondary</td>\n",
|
|
" <td>lower_mid</td>\n",
|
|
" <td>1553.938944</td>\n",
|
|
" <td>True</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>False</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 9e7811bd-27ee-4b7c-9b7a-c98441e337f0 40160 2024-08-18 19:10:06 \n",
|
|
"1 85ec8b90-5468-4880-8309-e325da14d877 55381 2025-01-24 12:21:13 \n",
|
|
"2 498dff10-040f-4206-8170-dfce0d5a69f0 48338 2025-07-15 22:21:54 \n",
|
|
"3 ddf11d94-5d6e-4322-9811-4e763f5ed46b 59925 2025-01-27 00:16:57 \n",
|
|
"4 2ef22a0c-fd13-4798-9276-f43831b8f7bc 68993 2024-08-19 04:21:49 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 KE sw web 28 male secondary lower_mid \n",
|
|
"1 TZ sw ios 23 female bachelor high \n",
|
|
"2 TZ en ios 49 male bachelor low \n",
|
|
"3 KE en web 22 male bachelor upper_mid \n",
|
|
"4 KE en android 40 male secondary lower_mid \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 1800.000000 True 4 3 3 \n",
|
|
"1 431.412783 True 3 2 3 \n",
|
|
"2 1800.000000 True 2 3 3 \n",
|
|
"3 656.050991 True 4 4 1 \n",
|
|
"4 1553.938944 True 2 2 5 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 3 4 True \n",
|
|
"1 4 4 False \n",
|
|
"2 1 3 False \n",
|
|
"3 3 5 False \n",
|
|
"4 1 5 False "
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n",
|
|
"df_llm.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"id": "6d4908ad",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🧪 Testing improved LLM generation with adaptive batching...\n",
|
|
"\n",
|
|
"📦 Testing small batch (10 rows)...\n",
|
|
"🔄 Generating 10 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
|
|
"📝 Raw response length: 5233 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 10 rows\n",
|
|
"✅ Successfully generated 10 survey responses\n",
|
|
"✅ Small batch result: 10 rows\n",
|
|
"\n",
|
|
"📦 Testing medium dataset (30 rows) with adaptive batching...\n",
|
|
"🚀 Generating 30 survey responses with adaptive batching\n",
|
|
"📊 Using optimal batch size: 15\n",
|
|
"\n",
|
|
"📦 Processing batch: 15 rows (remaining: 30)\n",
|
|
"🔄 Generating 15 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 5000 (estimated: 5000)\n",
|
|
"📝 Raw response length: 7839 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 15 rows\n",
|
|
"✅ Successfully generated 15 survey responses\n",
|
|
"\n",
|
|
"📦 Processing batch: 15 rows (remaining: 15)\n",
|
|
"🔄 Generating 15 survey responses with LLM...\n",
|
|
"📊 Using max_tokens: 5000 (estimated: 5000)\n",
|
|
"📝 Raw response length: 7841 characters\n",
|
|
"🔍 Parsed JSON type: <class 'dict'>\n",
|
|
"📊 Found data in 'responses': 15 rows\n",
|
|
"✅ Successfully generated 15 survey responses\n",
|
|
"✅ Generated total: 30 survey responses\n",
|
|
"✅ Medium dataset result: 30 rows\n",
|
|
"\n",
|
|
"📊 Dataset shape: (30, 18)\n",
|
|
"\n",
|
|
"📋 First few rows:\n",
|
|
" response_id respondent_id submitted_at \\\n",
|
|
"0 d1e5c4a3-4b1f-4f6b-8f9e-9f1e1f2e3d4c 10001 2023-10-01 14:30:00 \n",
|
|
"1 c2b1d4a6-7f8e-4c5c-9d8f-1e2c3b4a5e6f 10002 2023-10-01 15:00:00 \n",
|
|
"2 e3f2c5b7-8a2d-4c8e-9f1b-2c3d4e5f6a7b 10003 2023-10-01 15:30:00 \n",
|
|
"3 f4a5b6c8-9d3e-4b1f-9f2c-3d4e5f6a7b8c 10004 2023-10-01 16:00:00 \n",
|
|
"4 g5b6c7d9-0e4f-4b2a-8f3d-4e5f6a7b8c9d 10005 2023-10-01 16:30:00 \n",
|
|
"\n",
|
|
" country language device age gender education income_band \\\n",
|
|
"0 KE en android 28 female bachelor upper_mid \n",
|
|
"1 UG sw web 35 male diploma lower_mid \n",
|
|
"2 TZ en ios 42 nonbinary postgraduate high \n",
|
|
"3 RW sw web 29 female secondary upper_mid \n",
|
|
"4 NG en android 50 male bachelor high \n",
|
|
"\n",
|
|
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
|
|
"0 450.0 True 5 4 5 \n",
|
|
"1 600.0 True 3 2 4 \n",
|
|
"2 720.0 True 4 5 4 \n",
|
|
"3 300.0 True 3 3 3 \n",
|
|
"4 540.0 True 5 5 5 \n",
|
|
"\n",
|
|
" q_support nps is_detractor \n",
|
|
"0 4 9 False \n",
|
|
"1 3 5 False \n",
|
|
"2 5 10 False \n",
|
|
"3 4 6 False \n",
|
|
"4 5 10 False \n",
|
|
"💾 Saved: data\\survey_adaptive_batch_20251023T005927Z.csv\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1770033334.py:22: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
|
|
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test the improved LLM generation with adaptive batching\n",
|
|
"print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n",
|
|
"\n",
|
|
"# Test with smaller dataset first\n",
|
|
"print(\"\\n📦 Testing small batch (10 rows)...\")\n",
|
|
"small_df = fixed_llm_generate_batch(CFG, 10)\n",
|
|
"print(f\"✅ Small batch result: {len(small_df)} rows\")\n",
|
|
"\n",
|
|
"# Test with medium dataset using adaptive batching\n",
|
|
"print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n",
|
|
"medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n",
|
|
"print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n",
|
|
"\n",
|
|
"if not medium_df.empty:\n",
|
|
" print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n",
|
|
" print(f\"\\n📋 First few rows:\")\n",
|
|
" print(medium_df.head())\n",
|
|
" \n",
|
|
" # Save the results\n",
|
|
" from pathlib import Path\n",
|
|
" out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
|
|
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
|
|
" csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n",
|
|
" medium_df.to_csv(csv_path, index=False)\n",
|
|
" print(f\"💾 Saved: {csv_path}\")\n",
|
|
"else:\n",
|
|
" print(\"❌ Medium dataset generation failed\")\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|