Files
LLM_Engineering_OLD/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb
The Top Dev 308df9e3bc Add Week 3 Survey Dataset Generator with improved LLM generation
- Fixed LLM generation issues with adaptive batching
- Added JSON repair mechanism for truncated responses
- Implemented retry logic with smaller batch sizes
- Enhanced error handling and fallback mechanisms
- Successfully generates realistic survey data using LLM
2025-10-23 04:05:05 +03:00

1898 lines
78 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "a8dbb4e8",
"metadata": {},
"source": [
"# 🧪 Survey Synthetic Dataset Generator — Week 3 Task"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "8d86f629",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Base libraries ready. Pandera available: True\n"
]
}
],
"source": [
"\n",
"import os, re, json, time, uuid, math, random\n",
"from datetime import datetime, timedelta\n",
"from typing import List, Dict, Any\n",
"import numpy as np, pandas as pd\n",
"import pandera.pandas as pa\n",
"random.seed(7); np.random.seed(7)\n",
"print(\"✅ Base libraries ready. Pandera available:\", pa is not None)\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f196ae73",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def extract_strict_json(text: str):\n",
"    \"\"\"Extract a JSON array from LLM output using several fallback strategies.\n",
"\n",
"    Tries, in order: direct parse (unwrapping dicts that hold the array under a\n",
"    well-known key or digit-indexed keys), code-fence stripping, locating the\n",
"    outermost [...] span, and repairing common defects (trailing commas,\n",
"    NaN/Infinity literals, invisible characters).\n",
"\n",
"    Raises ValueError when no parseable JSON array can be recovered.\n",
"    \"\"\"\n",
"    if text is None:\n",
"        raise ValueError(\"Empty model output.\")\n",
"\n",
"    t = text.strip()\n",
"\n",
"    # Strategy 1: Direct JSON parsing\n",
"    try:\n",
"        obj = json.loads(t)\n",
"        if isinstance(obj, list):\n",
"            return obj\n",
"        elif isinstance(obj, dict):\n",
"            for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n",
"                if key in obj and isinstance(obj[key], list):\n",
"                    return obj[key]\n",
"            # Dict keyed by \"0\", \"1\", ... is treated as an ordered array.\n",
"            if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n",
"                return [obj[k] for k in sorted(obj.keys(), key=int)]\n",
"    except json.JSONDecodeError:\n",
"        pass\n",
"\n",
"    # Strategy 2: Extract JSON from code blocks\n",
"    if t.startswith(\"```\"):\n",
"        t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n",
"\n",
"    # Strategy 3: Find JSON array in text\n",
"    start, end = t.find('['), t.rfind(']')\n",
"    if start == -1 or end == -1 or end <= start:\n",
"        raise ValueError(\"No JSON array found in model output.\")\n",
"\n",
"    t = t[start:end+1]\n",
"\n",
"    # Strategy 4: Fix common JSON issues\n",
"    t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t)  # Remove trailing commas\n",
"    # BUGFIX: the old pattern \"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\" turned\n",
"    # \"-Infinity\" into \"-null\" (invalid JSON): \\b cannot match between two\n",
"    # non-word chars, so the \"-Infinity\" alternative never fired and the bare\n",
"    # \"Infinity\" alternative matched instead. Consume the optional sign here.\n",
"    t = re.sub(r\"-?\\b(?:NaN|Infinity)\\b\", \"null\", t)  # Replace NaN/Infinity with null\n",
"    t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\")  # Remove invisible characters\n",
"\n",
"    try:\n",
"        return json.loads(t)\n",
"    except json.JSONDecodeError as e:\n",
"        raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n"
]
},
{
"cell_type": "markdown",
"id": "3670fa0d",
"metadata": {},
"source": [
"## 1) Configuration"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "d16bd03a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded config for 800 rows and 18 fields.\n"
]
}
],
"source": [
"\n",
"# Dataset generation config: row count, submission-timestamp window, and one\n",
"# spec per output column (type plus type-specific parameters).\n",
"CFG = {\n",
"    \"rows\": 800,\n",
"    # Window and output format for the synthetic submitted_at timestamps.\n",
"    \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n",
"    # Field specs: \"enum\" fields carry candidate values with sampling\n",
"    # probabilities (probs should sum to 1); numeric fields carry min/max and\n",
"    # an optional distribution name understood by sample_numeric.\n",
"    \"fields\": [\n",
"        {\"name\": \"response_id\", \"type\": \"uuid4\"},\n",
"        {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n",
"        {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n",
"        {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n",
"        {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n",
"        {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n",
"        {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n",
"        {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n",
"        {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n",
"        {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n",
"        {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n",
"        {\"name\": \"attention_passed\", \"type\": \"bool\"},\n",
"        # Four 1-5 likert questions; nps is later re-derived from their mean\n",
"        # and is_detractor from completion time + attention check (see the\n",
"        # rule-based generator), so these two specs only fix column presence.\n",
"        {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
"        {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
"        {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
"        {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
"        {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n",
"        {\"name\": \"is_detractor\", \"type\": \"bool\"}\n",
"    ]\n",
"}\n",
"print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n"
]
},
{
"cell_type": "markdown",
"id": "7da1f429",
"metadata": {},
"source": [
"## 2) Helpers"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "d2f5fdff",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def sample_enum(values, probs=None, size=None):\n",
"    \"\"\"Sample categorical values with optional probabilities (uniform by default).\"\"\"\n",
"    values = list(values)\n",
"    if probs is None:\n",
"        probs = [1.0 / len(values)] * len(values)\n",
"    return np.random.choice(values, p=probs, size=size)\n",
"\n",
"def sample_numeric(field_cfg, size=1):\n",
"    \"\"\"Sample an int/float column per its config (min/max + optional distribution).\n",
"\n",
"    Supported distributions: uniform (default, also the fallback for unknown\n",
"    names), normal (clipped to [min, max]), and for floats lognormal.\n",
"\n",
"    BUGFIX: integers are now sampled as int64 explicitly. np.random.randint\n",
"    defaults to the platform C long (int32 on Windows), which made the pandera\n",
"    int64 schema checks fail for every int column.\n",
"    \"\"\"\n",
"    t = field_cfg[\"type\"]\n",
"    if t == \"int\":\n",
"        lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n",
"        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
"        if dist == \"normal\":\n",
"            mu = (lo + hi) / 2.0\n",
"            sigma = (hi - lo) / 6.0  # ~99.7% of draws inside [lo, hi] before clipping\n",
"            out = np.random.normal(mu, sigma, size=size)\n",
"            return np.clip(out, lo, hi).astype(np.int64)\n",
"        # uniform and any unknown distribution name: inclusive uniform draw\n",
"        return np.random.randint(lo, hi + 1, size=size, dtype=np.int64)\n",
"    elif t == \"float\":\n",
"        lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n",
"        dist = field_cfg.get(\"distribution\", \"uniform\")\n",
"        if dist == \"normal\":\n",
"            mu = (lo + hi) / 2.0\n",
"            sigma = (hi - lo) / 6.0\n",
"            return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n",
"        elif dist == \"lognormal\":\n",
"            mu = math.log(max(1e-3, (lo + hi) / 2.0))  # median near the range midpoint\n",
"            sigma = 0.75\n",
"            out = np.random.lognormal(mu, sigma, size=size)\n",
"            return np.clip(out, lo, hi)\n",
"        # uniform and any unknown distribution name\n",
"        return np.random.uniform(lo, hi, size=size)\n",
"    else:\n",
"        raise ValueError(\"Unsupported numeric type\")\n",
"\n",
"def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n",
"    \"\"\"Sample formatted datetime strings uniformly between two ISO-format dates.\"\"\"\n",
"    s = datetime.fromisoformat(start)\n",
"    e = datetime.fromisoformat(end)\n",
"    total = int((e - s).total_seconds())\n",
"    r = np.random.randint(0, total, size=size)\n",
"    return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n"
]
},
{
"cell_type": "markdown",
"id": "5f24111a",
"metadata": {},
"source": [
"## 3) Rule-based Generator"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "cd61330d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>response_id</th>\n",
" <th>respondent_id</th>\n",
" <th>submitted_at</th>\n",
" <th>country</th>\n",
" <th>language</th>\n",
" <th>device</th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>education</th>\n",
" <th>income_band</th>\n",
" <th>completion_seconds</th>\n",
" <th>attention_passed</th>\n",
" <th>q_quality</th>\n",
" <th>q_value</th>\n",
" <th>q_ease</th>\n",
" <th>q_support</th>\n",
" <th>nps</th>\n",
" <th>is_detractor</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>f099c1b6-a4ae-4fb0-ba98-89a81008c424</td>\n",
" <td>71615</td>\n",
" <td>2024-04-13 19:02:44</td>\n",
" <td>ZA</td>\n",
" <td>en</td>\n",
" <td>web</td>\n",
" <td>47</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>low</td>\n",
" <td>897.995012</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b</td>\n",
" <td>68564</td>\n",
" <td>2024-03-05 23:30:30</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>android</td>\n",
" <td>67</td>\n",
" <td>female</td>\n",
" <td>bachelor</td>\n",
" <td>lower_mid</td>\n",
" <td>935.607966</td>\n",
" <td>True</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>a9345f69-be75-46b9-8cd3-a276ce0a66bd</td>\n",
" <td>59689</td>\n",
" <td>2024-11-10 03:38:07</td>\n",
" <td>RW</td>\n",
" <td>sw</td>\n",
" <td>android</td>\n",
" <td>23</td>\n",
" <td>male</td>\n",
" <td>bachelor</td>\n",
" <td>low</td>\n",
" <td>1431.517701</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>7</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>b4fa8625-d153-4465-ad73-1c4a48eed2f1</td>\n",
" <td>20742</td>\n",
" <td>2024-11-19 17:40:58</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>ios</td>\n",
" <td>68</td>\n",
" <td>female</td>\n",
" <td>secondary</td>\n",
" <td>upper_mid</td>\n",
" <td>448.519416</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>10</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>e0ad4bbc-b576-4913-8786-302f06b5e9f7</td>\n",
" <td>63459</td>\n",
" <td>2024-07-28 04:23:37</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>ios</td>\n",
" <td>34</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>low</td>\n",
" <td>1179.970734</td>\n",
" <td>True</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" response_id respondent_id submitted_at \\\n",
"0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
"1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
"2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
"3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
"4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 ZA en web 47 male secondary low \n",
"1 KE en android 67 female bachelor lower_mid \n",
"2 RW sw android 23 male bachelor low \n",
"3 KE en ios 68 female secondary upper_mid \n",
"4 KE en ios 34 male secondary low \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 897.995012 True 5 3 1 \n",
"1 935.607966 True 1 5 2 \n",
"2 1431.517701 True 5 2 5 \n",
"3 448.519416 True 5 5 5 \n",
"4 1179.970734 True 3 1 3 \n",
"\n",
" q_support nps is_detractor \n",
"0 3 4 True \n",
"1 3 5 False \n",
"2 5 7 False \n",
"3 3 10 False \n",
"4 3 5 False "
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n",
"    \"\"\"Generate CFG[\"rows\"] synthetic survey rows without any LLM call.\n",
"\n",
"    Columns are first sampled independently per the field specs; two derived\n",
"    columns are then overwritten with correlated values: nps is driven by the\n",
"    mean of the four likert questions plus noise, and is_detractor is made\n",
"    more likely for slow completions and failed attention checks.\n",
"    \"\"\"\n",
"    n = CFG[\"rows\"]\n",
"    dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n",
"    data = {}\n",
"    for f in CFG[\"fields\"]:\n",
"        name, t = f[\"name\"], f[\"type\"]\n",
"        if t == \"uuid4\":\n",
"            data[name] = [str(uuid.uuid4()) for _ in range(n)]\n",
"        elif t in (\"int\",\"float\"):\n",
"            data[name] = sample_numeric(f, size=n)\n",
"        elif t == \"enum\":\n",
"            data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n",
"        elif t == \"datetime\":\n",
"            data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n",
"        elif t == \"bool\":\n",
"            data[name] = np.random.rand(n) < 0.9  # 90% True\n",
"        else:\n",
"            data[name] = [None]*n  # unknown type: emit empty column rather than fail\n",
"    df = pd.DataFrame(data)\n",
"\n",
"    # Derive NPS roughly from likert questions\n",
"    if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n",
"        likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n",
"        # Map likert mean 1..5 onto 0..10, add noise, clip to the NPS range.\n",
"        # BUGFIX: astype(np.int64) instead of astype(int) so the column stays\n",
"        # int64 on Windows too (plain int maps to int32 there), matching the\n",
"        # pandera schema.\n",
"        df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(np.int64)\n",
"\n",
"    # Heuristic target: is_detractor more likely when completion high & attention failed\n",
"    if \"is_detractor\" in df.columns:\n",
"        base = 0.25\n",
"        comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n",
"        attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n",
"        # +0.15 probability per risk factor: slow completion (>900s), failed check.\n",
"        boost = (comp > 900).astype(int) + (~attn).astype(int)\n",
"        p = np.clip(base + 0.15*boost, 0.01, 0.95)\n",
"        df[\"is_detractor\"] = np.random.rand(n) < p\n",
"\n",
"    return df\n",
"\n",
"df_rule = generate_rule_based(CFG)\n",
"df_rule.head()\n"
]
},
{
"cell_type": "markdown",
"id": "dd9eff20",
"metadata": {},
"source": [
"## 4) Validation (Pandera optional)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "9a4ef86a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation error: {\n",
" \"SCHEMA\": {\n",
" \"WRONG_DATATYPE\": [\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"respondent_id\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'respondent_id' to have type int64, got int32\"\n",
" },\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"age\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'age' to have type int64, got int32\"\n",
" },\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"q_quality\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'q_quality' to have type int64, got int32\"\n",
" },\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"q_value\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'q_value' to have type int64, got int32\"\n",
" },\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"q_ease\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'q_ease' to have type int64, got int32\"\n",
" },\n",
" {\n",
" \"schema\": null,\n",
" \"column\": \"q_support\",\n",
" \"check\": \"dtype('int64')\",\n",
" \"error\": \"expected series 'q_support' to have type int64, got int32\"\n",
" }\n",
" ]\n",
" }\n",
"}\n",
"{'engine': 'pandera', 'valid_rows': 800, 'invalid_rows': 0, 'notes': 'Non-strict mode.'}\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>response_id</th>\n",
" <th>respondent_id</th>\n",
" <th>submitted_at</th>\n",
" <th>country</th>\n",
" <th>language</th>\n",
" <th>device</th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>education</th>\n",
" <th>income_band</th>\n",
" <th>completion_seconds</th>\n",
" <th>attention_passed</th>\n",
" <th>q_quality</th>\n",
" <th>q_value</th>\n",
" <th>q_ease</th>\n",
" <th>q_support</th>\n",
" <th>nps</th>\n",
" <th>is_detractor</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>f099c1b6-a4ae-4fb0-ba98-89a81008c424</td>\n",
" <td>71615</td>\n",
" <td>2024-04-13 19:02:44</td>\n",
" <td>ZA</td>\n",
" <td>en</td>\n",
" <td>web</td>\n",
" <td>47</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>low</td>\n",
" <td>897.995012</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b</td>\n",
" <td>68564</td>\n",
" <td>2024-03-05 23:30:30</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>android</td>\n",
" <td>67</td>\n",
" <td>female</td>\n",
" <td>bachelor</td>\n",
" <td>lower_mid</td>\n",
" <td>935.607966</td>\n",
" <td>True</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>a9345f69-be75-46b9-8cd3-a276ce0a66bd</td>\n",
" <td>59689</td>\n",
" <td>2024-11-10 03:38:07</td>\n",
" <td>RW</td>\n",
" <td>sw</td>\n",
" <td>android</td>\n",
" <td>23</td>\n",
" <td>male</td>\n",
" <td>bachelor</td>\n",
" <td>low</td>\n",
" <td>1431.517701</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>7</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>b4fa8625-d153-4465-ad73-1c4a48eed2f1</td>\n",
" <td>20742</td>\n",
" <td>2024-11-19 17:40:58</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>ios</td>\n",
" <td>68</td>\n",
" <td>female</td>\n",
" <td>secondary</td>\n",
" <td>upper_mid</td>\n",
" <td>448.519416</td>\n",
" <td>True</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>10</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>e0ad4bbc-b576-4913-8786-302f06b5e9f7</td>\n",
" <td>63459</td>\n",
" <td>2024-07-28 04:23:37</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>ios</td>\n",
" <td>34</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>low</td>\n",
" <td>1179.970734</td>\n",
" <td>True</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" response_id respondent_id submitted_at \\\n",
"0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
"1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
"2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
"3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
"4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 ZA en web 47 male secondary low \n",
"1 KE en android 67 female bachelor lower_mid \n",
"2 RW sw android 23 male bachelor low \n",
"3 KE en ios 68 female secondary upper_mid \n",
"4 KE en ios 34 male secondary low \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 897.995012 True 5 3 1 \n",
"1 935.607966 True 1 5 2 \n",
"2 1431.517701 True 5 2 5 \n",
"3 448.519416 True 5 5 5 \n",
"4 1179.970734 True 3 1 3 \n",
"\n",
" q_support nps is_detractor \n",
"0 3 4 True \n",
"1 3 5 False \n",
"2 5 7 False \n",
"3 3 10 False \n",
"4 3 5 False "
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"def build_pandera_schema(CFG):\n",
"    \"\"\"Build a pandera DataFrameSchema from CFG, or None when pandera is absent.\n",
"\n",
"    coerce=True lets platform-dependent dtypes (e.g. int32 from\n",
"    np.random.randint on Windows) be cast to the declared dtype instead of\n",
"    failing validation outright, which is what produced the WRONG_DATATYPE\n",
"    errors seen previously.\n",
"    \"\"\"\n",
"    if pa is None:\n",
"        return None\n",
"    cols = {}\n",
"    for f in CFG[\"fields\"]:\n",
"        t, name = f[\"type\"], f[\"name\"]\n",
"        if t == \"int\": cols[name] = pa.Column(int, coerce=True)\n",
"        elif t == \"float\": cols[name] = pa.Column(float, coerce=True)\n",
"        elif t == \"bool\": cols[name] = pa.Column(bool, coerce=True)\n",
"        else: cols[name] = pa.Column(object)  # enum/datetime/uuid4 stay as strings\n",
"    return pa.DataFrameSchema(cols)\n",
"\n",
"def validate_df(df, CFG):\n",
"    \"\"\"Validate df against the CFG-derived schema; never raises.\n",
"\n",
"    Returns (possibly-coerced df, report dict). On validation failure the\n",
"    error is printed and the original df is returned unchanged with a note,\n",
"    keeping the pipeline running (non-strict mode).\n",
"    \"\"\"\n",
"    schema = build_pandera_schema(CFG)\n",
"    if schema is None:\n",
"        return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n",
"    try:\n",
"        v = schema.validate(df, lazy=True)\n",
"        return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n",
"    except Exception as e:\n",
"        print(\"Validation error:\", e)\n",
"        return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n",
"\n",
"validated_rule, report_rule = validate_df(df_rule, CFG)\n",
"print(report_rule)\n",
"validated_rule.head()\n"
]
},
{
"cell_type": "markdown",
"id": "d5f1d93a",
"metadata": {},
"source": [
"## 5) Save"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "73626b4c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: data/survey_rule_20251023T004106Z.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1233117399.py:3: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
]
}
],
"source": [
"\n",
"from pathlib import Path\n",
"from datetime import timezone\n",
"\n",
"# Timestamped CSV export of the validated rule-based dataset.\n",
"# BUGFIX: datetime.now(timezone.utc) replaces datetime.utcnow(), which is\n",
"# deprecated (it raised a DeprecationWarning on every run of this cell).\n",
"out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
"ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
"csv_path = out / f\"survey_rule_{ts}.csv\"\n",
"validated_rule.to_csv(csv_path, index=False)\n",
"print(\"Saved:\", csv_path.as_posix())\n"
]
},
{
"cell_type": "markdown",
"id": "87c89b51",
"metadata": {},
"source": [
"## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "24e94771",
"metadata": {},
"outputs": [],
"source": [
"# Fixed LLM Generation Functions\n",
"def create_survey_prompt(CFG, n_rows=50):\n",
"    \"\"\"Create a clear, structured prompt for survey data generation.\n",
"\n",
"    Renders one human-readable constraint line per field in CFG['fields']\n",
"    (range for numerics, allowed values for enums, format for datetimes) and\n",
"    asks for a JSON object wrapping a \"responses\" array of exactly n_rows\n",
"    objects — the object wrapper matches the response_format json_object mode\n",
"    used by the generation call.\n",
"    \"\"\"\n",
"    fields_desc = []\n",
"    for field in CFG['fields']:\n",
"        name = field['name']\n",
"        field_type = field['type']\n",
"\n",
"        # One constraint line per field; unknown types fall back to the raw\n",
"        # type name so the prompt never drops a column.\n",
"        if field_type == 'int':\n",
"            min_val = field.get('min', 0)\n",
"            max_val = field.get('max', 100)\n",
"            fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n",
"        elif field_type == 'float':\n",
"            min_val = field.get('min', 0.0)\n",
"            max_val = field.get('max', 100.0)\n",
"            fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n",
"        elif field_type == 'enum':\n",
"            values = field.get('values', [])\n",
"            fields_desc.append(f\" - {name}: one of {values}\")\n",
"        elif field_type == 'bool':\n",
"            fields_desc.append(f\" - {name}: boolean (true/false)\")\n",
"        elif field_type == 'uuid4':\n",
"            fields_desc.append(f\" - {name}: UUID string\")\n",
"        elif field_type == 'datetime':\n",
"            fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n",
"            fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n",
"        else:\n",
"            fields_desc.append(f\" - {name}: {field_type}\")\n",
"\n",
"    # chr(10) is '\\n': a literal backslash-escape cannot appear inside an\n",
"    # f-string expression on older Pythons, hence the chr() workaround.\n",
"    prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n",
"\n",
"Schema:\n",
"{chr(10).join(fields_desc)}\n",
"\n",
"CRITICAL REQUIREMENTS:\n",
"- Return a JSON object with a \"responses\" key containing an array\n",
"- Each object in the array must have all required fields\n",
"- Use realistic, diverse values for survey responses\n",
"- No trailing commas\n",
"- No comments or explanations\n",
"\n",
"Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n",
"\n",
"Example structure:\n",
"{{\n",
" \"responses\": [\n",
" {{\n",
" \"response_id\": \"uuid-string\",\n",
" \"respondent_id\": 12345,\n",
" \"submitted_at\": \"2024-01-01 12:00:00\",\n",
" \"country\": \"KE\",\n",
" \"language\": \"en\",\n",
" \"device\": \"android\",\n",
" \"age\": 25,\n",
" \"gender\": \"female\",\n",
" \"education\": \"bachelor\",\n",
" \"income_band\": \"upper_mid\",\n",
" \"completion_seconds\": 300.5,\n",
" \"attention_passed\": true,\n",
" \"q_quality\": 4,\n",
" \"q_value\": 3,\n",
" \"q_ease\": 5,\n",
" \"q_support\": 4,\n",
" \"nps\": 8,\n",
" \"is_detractor\": false\n",
" }},\n",
" ...\n",
" ]\n",
"}}\n",
"\n",
"IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n",
"\n",
"    return prompt\n",
"\n",
"def repair_truncated_json(content):\n",
"    \"\"\"Attempt to repair truncated JSON responses.\n",
"\n",
"    When the payload opens with '{' but was cut off before its closing '}',\n",
"    scan the \"responses\" array for the LAST complete object, drop the\n",
"    truncated tail, and close the array and object. Returns the input\n",
"    unchanged when no repair applies.\n",
"    \"\"\"\n",
"    content = content.strip()\n",
"\n",
"    # If it starts with { but doesn't end with }, try to close it\n",
"    if content.startswith('{') and not content.endswith('}'):\n",
"        responses_start = content.find('\"responses\": [')\n",
"        if responses_start != -1:\n",
"            # Walk the array character by character, tracking string/escape\n",
"            # state so braces inside string values are ignored.\n",
"            brace_count = 0\n",
"            last_complete_pos = -1\n",
"            in_string = False\n",
"            escape_next = False\n",
"\n",
"            for i, char in enumerate(content[responses_start:], responses_start):\n",
"                if escape_next:\n",
"                    escape_next = False\n",
"                    continue\n",
"                if char == '\\\\':\n",
"                    escape_next = True\n",
"                    continue\n",
"                # escape_next is always False here (handled above), so every\n",
"                # unescaped quote toggles string state.\n",
"                if char == '\"':\n",
"                    in_string = not in_string\n",
"                    continue\n",
"                if not in_string:\n",
"                    if char == '{':\n",
"                        brace_count += 1\n",
"                    elif char == '}':\n",
"                        brace_count -= 1\n",
"                        # BUGFIX: keep scanning instead of breaking on the\n",
"                        # first balanced object — the old break kept only ONE\n",
"                        # row and discarded every later complete object.\n",
"                        if brace_count == 0:\n",
"                            last_complete_pos = i\n",
"\n",
"            if last_complete_pos != -1:\n",
"                # Truncate at the last complete object and close the JSON\n",
"                repaired = content[:last_complete_pos + 1] + '\\n  ]\\n}'\n",
"                print(f\"🔧 Repaired JSON: truncated at position {last_complete_pos}\")\n",
"                return repaired\n",
"\n",
"    return content\n",
"\n",
"def fixed_llm_generate_batch(CFG, n_rows=50):\n",
"    \"\"\"Generate one batch of n_rows survey responses via the OpenAI API.\n",
"\n",
"    Falls back to generate_rule_based when no OPENAI_API_KEY is set or when\n",
"    any step of the call/parse pipeline fails, so this function always\n",
"    returns a DataFrame. The pipeline: build the prompt, call gpt-4o-mini in\n",
"    JSON mode with a size-scaled max_tokens, repair apparent truncation, then\n",
"    parse the payload via a cascade of shapes (bare array, dict with a known\n",
"    list key, dict with any list value, dict of row dicts) before retrying\n",
"    with extract_strict_json.\n",
"    \"\"\"\n",
"    if not os.getenv('OPENAI_API_KEY'):\n",
"        print(\"No OpenAI API key, using rule-based fallback\")\n",
"        # Copy CFG so the caller's row count is not mutated.\n",
"        tmp = dict(CFG); tmp['rows'] = n_rows\n",
"        return generate_rule_based(tmp)\n",
"\n",
"    try:\n",
"        from openai import OpenAI\n",
"        client = OpenAI()\n",
"\n",
"        prompt = create_survey_prompt(CFG, n_rows)\n",
"\n",
"        print(f\"🔄 Generating {n_rows} survey responses with LLM...\")\n",
"\n",
"        # Calculate appropriate max_tokens based on batch size\n",
"        # Roughly 200-300 tokens per row, with some buffer\n",
"        estimated_tokens = n_rows * 300 + 500  # Buffer for JSON structure\n",
"        max_tokens = min(max(estimated_tokens, 2000), 8000)  # Between 2k-8k tokens\n",
"\n",
"        print(f\"📊 Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n",
"\n",
"        response = client.chat.completions.create(\n",
"            model='gpt-4o-mini',\n",
"            messages=[\n",
"                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n",
"                {'role': 'user', 'content': prompt}\n",
"            ],\n",
"            temperature=0.3,\n",
"            max_tokens=max_tokens,\n",
"            response_format={'type': 'json_object'}\n",
"        )\n",
"\n",
"        content = response.choices[0].message.content\n",
"        print(f\"📝 Raw response length: {len(content)} characters\")\n",
"\n",
"        # Check if response appears truncated (neither closing brace nor\n",
"        # bracket at the end) and try to salvage the complete rows.\n",
"        if not content.strip().endswith('}') and not content.strip().endswith(']'):\n",
"            print(\"⚠️ Response appears truncated, attempting repair...\")\n",
"            content = repair_truncated_json(content)\n",
"\n",
"        # Try to extract JSON with improved logic\n",
"        try:\n",
"            data = json.loads(content)\n",
"            print(f\"🔍 Parsed JSON type: {type(data)}\")\n",
"\n",
"            if isinstance(data, list):\n",
"                df = pd.DataFrame(data)\n",
"                print(f\"📊 Direct array: {len(df)} rows\")\n",
"            elif isinstance(data, dict):\n",
"                # Check for common keys that might contain the data\n",
"                for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n",
"                    if key in data and isinstance(data[key], list):\n",
"                        df = pd.DataFrame(data[key])\n",
"                        print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
"                        break\n",
"                else:  # for/else: runs only when no known key matched\n",
"                    # If no standard key found, check if all values are lists/objects\n",
"                    list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n",
"                    if list_keys:\n",
"                        # Use the first list key found\n",
"                        key = list_keys[0]\n",
"                        df = pd.DataFrame(data[key])\n",
"                        print(f\"📊 Found data in '{key}': {len(df)} rows\")\n",
"                    else:\n",
"                        # Try to convert the dict values to a list of rows\n",
"                        if all(isinstance(v, dict) for v in data.values()):\n",
"                            df = pd.DataFrame(list(data.values()))\n",
"                            print(f\"📊 Converted dict values: {len(df)} rows\")\n",
"                        else:\n",
"                            raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n",
"            else:\n",
"                raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n",
"\n",
"            # A short batch is still usable; only an empty one is an error.\n",
"            if len(df) == n_rows:\n",
"                print(f\"✅ Successfully generated {len(df)} survey responses\")\n",
"                return df\n",
"            else:\n",
"                print(f\"⚠️ Generated {len(df)} rows, expected {n_rows}\")\n",
"                if len(df) > 0:\n",
"                    return df\n",
"                else:\n",
"                    raise ValueError(\"No data generated\")\n",
"\n",
"        except json.JSONDecodeError as e:\n",
"            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
"            # Try the improved extract_strict_json function as a last resort\n",
"            try:\n",
"                data = extract_strict_json(content)\n",
"                df = pd.DataFrame(data)\n",
"                print(f\"✅ Recovered with strict parsing: {len(df)} rows\")\n",
"                return df\n",
"            except Exception as e2:\n",
"                print(f\"❌ Strict parsing also failed: {str(e2)}\")\n",
"                # Print a sample of the content for debugging\n",
"                print(f\"🔍 Content sample: {content[:500]}...\")\n",
"                raise e2\n",
"\n",
"    except Exception as e:\n",
"        # Any failure (API, parse, empty batch) degrades to synthetic data so\n",
"        # the notebook keeps running end to end.\n",
"        print(f'❌ LLM error, fallback to rule-based mock: {str(e)}')\n",
"        tmp = dict(CFG); tmp['rows'] = n_rows\n",
"        return generate_rule_based(tmp)\n",
"\n",
"def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n",
"    \"\"\"Generate total_rows survey responses via the LLM in adaptive batches.\n",
"\n",
"    The per-call batch size shrinks as total_rows grows (smaller batches\n",
"    truncate less often). A failed batch is retried once at half size; if\n",
"    that also fails, the remaining rows come from the rule-based generator.\n",
"    Returns a single concatenated DataFrame (empty when nothing was made).\n",
"    \"\"\"\n",
"    print(f\"🚀 Generating {total_rows} survey responses with adaptive batching\")\n",
"\n",
"    # Adaptive batch sizing based on total rows\n",
"    if total_rows <= 20:\n",
"        optimal_batch_size = min(batch_size, total_rows)\n",
"    elif total_rows <= 50:\n",
"        optimal_batch_size = min(15, batch_size)\n",
"    elif total_rows <= 100:\n",
"        optimal_batch_size = min(10, batch_size)\n",
"    else:\n",
"        optimal_batch_size = min(8, batch_size)\n",
"\n",
"    print(f\"📊 Using optimal batch size: {optimal_batch_size}\")\n",
"\n",
"    all_dataframes = []\n",
"    remaining = total_rows\n",
"\n",
"    while remaining > 0:\n",
"        current_batch_size = min(optimal_batch_size, remaining)\n",
"        print(f\"\\n📦 Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n",
"\n",
"        try:\n",
"            batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n",
"            if len(batch_df) == 0:\n",
"                # Guard: an empty batch would never decrease `remaining` and\n",
"                # the loop would spin forever — treat it as a failure instead.\n",
"                raise ValueError(\"Batch produced 0 rows\")\n",
"            all_dataframes.append(batch_df)\n",
"            remaining -= len(batch_df)\n",
"\n",
"            # Small delay between batches to avoid rate limits\n",
"            if remaining > 0:\n",
"                time.sleep(1.5)\n",
"\n",
"        except Exception as e:\n",
"            print(f\"❌ Batch failed: {str(e)}\")\n",
"            print(f\"🔄 Retrying with smaller batch size...\")\n",
"\n",
"            # Try once more with half the batch size\n",
"            smaller_batch = max(1, current_batch_size // 2)\n",
"            if smaller_batch < current_batch_size:\n",
"                try:\n",
"                    print(f\"🔄 Retrying with {smaller_batch} rows...\")\n",
"                    batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n",
"                    if len(batch_df) > 0:\n",
"                        all_dataframes.append(batch_df)\n",
"                        remaining -= len(batch_df)\n",
"                        continue\n",
"                except Exception as e2:\n",
"                    print(f\"❌ Retry also failed: {str(e2)}\")\n",
"\n",
"            # BUGFIX: generate_rule_based takes a single config dict; the old\n",
"            # call generate_rule_based(CFG, remaining) raised a TypeError and\n",
"            # killed this fallback path. Pass a copy with rows overridden.\n",
"            print(f\"Using rule-based fallback for remaining {remaining} rows\")\n",
"            tmp = dict(CFG); tmp['rows'] = remaining\n",
"            fallback_df = generate_rule_based(tmp)\n",
"            all_dataframes.append(fallback_df)\n",
"            break\n",
"\n",
"    if all_dataframes:\n",
"        result = pd.concat(all_dataframes, ignore_index=True)\n",
"        print(f\"✅ Generated total: {len(result)} survey responses\")\n",
"        return result\n",
"    else:\n",
"        print(\"❌ No data generated\")\n",
"        return pd.DataFrame()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1af410e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧪 Testing LLM generation...\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5236 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📊 Generated dataset shape: (10, 18)\n",
"\n",
"📋 First few rows:\n",
" response_id respondent_id submitted_at \\\n",
"0 f3e9b9d1-4e9e-4f8a-9b5c-7e3cbb1c4e5e 10234 2023-10-01 14:23:45 \n",
"1 a1c5f6d3-1f5b-4e8a-8c7a-5e2c3f4b8e1b 20456 2023-10-01 15:10:12 \n",
"2 c2b3e4f5-5d6e-4b8a-9f3c-8e1a2f9b4e3c 30567 2023-10-01 16:45:30 \n",
"3 d4e5f6b7-6e8f-4b9a-8c7d-9e2f3c4b5e6f 40678 2023-10-01 17:30:00 \n",
"4 e5f6a7b8-7f9a-4c0a-9e2f-1e3c4b5e6f7a 50789 2023-10-01 18:15:15 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 KE en android 29 female bachelor upper_mid \n",
"1 UG sw web 34 male secondary lower_mid \n",
"2 TZ en ios 42 nonbinary diploma high \n",
"3 RW sw android 27 female bachelor upper_mid \n",
"4 NG en web 36 male postgraduate high \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 450.0 True 4 5 4 \n",
"1 600.5 True 3 4 3 \n",
"2 720.0 True 5 5 5 \n",
"3 390.0 True 4 4 4 \n",
"4 800.0 True 5 5 5 \n",
"\n",
" q_support nps is_detractor \n",
"0 5 9 False \n",
"1 4 7 False \n",
"2 5 10 False \n",
"3 4 8 False \n",
"4 5 9 False \n",
"\n",
"📈 Data types:\n",
"response_id object\n",
"respondent_id int64\n",
"submitted_at object\n",
"country object\n",
"language object\n",
"device object\n",
"age int64\n",
"gender object\n",
"education object\n",
"income_band object\n",
"completion_seconds float64\n",
"attention_passed bool\n",
"q_quality int64\n",
"q_value int64\n",
"q_ease int64\n",
"q_support int64\n",
"nps int64\n",
"is_detractor bool\n",
"dtype: object\n"
]
}
],
"source": [
"# Smoke-test the fixed LLM generation path\n",
"print(\"🧪 Testing LLM generation...\")\n",
"\n",
"# Start with a small sample before attempting larger runs\n",
"test_df = fixed_llm_generate_batch(CFG, 10)\n",
"print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
"print(f\"\\n📋 First few rows:\")\n",
"print(test_df.head())\n",
"print(f\"\\n📈 Data types:\")\n",
"print(test_df.dtypes)\n",
"\n",
"# Helper for inspecting the raw LLM reply when parsing misbehaves\n",
"def debug_llm_response(CFG, n_rows=5):\n",
"    \"\"\"Send one generation request and report how the raw reply parses.\n",
"\n",
"    Prints prompt/response sizes, previews of the payload, and the parsed\n",
"    JSON structure (dict keys or list length) so truncation or formatting\n",
"    problems are easy to spot. Requires OPENAI_API_KEY to be set.\n",
"    \"\"\"\n",
"    if not os.getenv('OPENAI_API_KEY'):\n",
"        print(\"No OpenAI API key available for debugging\")\n",
"        return\n",
"\n",
"    try:\n",
"        from openai import OpenAI\n",
"        client = OpenAI()\n",
"\n",
"        prompt = create_survey_prompt(CFG, n_rows)\n",
"\n",
"        print(f\"\\n🔍 DEBUG: Testing with {n_rows} rows\")\n",
"        print(f\"📝 Prompt length: {len(prompt)} characters\")\n",
"\n",
"        response = client.chat.completions.create(\n",
"            model='gpt-4o-mini',\n",
"            messages=[\n",
"                {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n",
"                {'role': 'user', 'content': prompt}\n",
"            ],\n",
"            temperature=0.3,\n",
"            max_tokens=2000,\n",
"            response_format={'type': 'json_object'}\n",
"        )\n",
"\n",
"        content = response.choices[0].message.content\n",
"        print(f\"📝 Raw response length: {len(content)} characters\")\n",
"        print(f\"🔍 First 200 characters: {content[:200]}\")\n",
"        print(f\"🔍 Last 200 characters: {content[-200:]}\")\n",
"\n",
"        # Parse separately so a JSON failure is reported, not raised\n",
"        try:\n",
"            parsed = json.loads(content)\n",
"            print(f\"✅ JSON parsed successfully\")\n",
"            print(f\"🔍 Data type: {type(parsed)}\")\n",
"            if isinstance(parsed, dict):\n",
"                print(f\"🔍 Dict keys: {list(parsed.keys())}\")\n",
"            elif isinstance(parsed, list):\n",
"                print(f\"🔍 List length: {len(parsed)}\")\n",
"        except Exception as e:\n",
"            print(f\"❌ JSON parsing failed: {str(e)}\")\n",
"\n",
"    except Exception as e:\n",
"        print(f\"❌ Debug failed: {str(e)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "75c90739",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧪 Testing the fixed LLM generation...\n",
"🔄 Generating 5 survey responses with LLM...\n",
"📊 Using max_tokens: 2000 (estimated: 2000)\n",
"📝 Raw response length: 2629 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 5 rows\n",
"✅ Successfully generated 5 survey responses\n",
"\n",
"📊 Generated dataset shape: (5, 18)\n",
"\n",
"📋 First few rows:\n",
" response_id respondent_id submitted_at \\\n",
"0 d8b1c6f3-6f7a-4b4f-9c5f-3a5f8b6e2f1e 12345 2023-10-01 14:30:00 \n",
"1 f3a8e3c1-9b4e-4e5e-9c2b-8f5e3c9b1f3d 67890 2023-10-01 15:00:00 \n",
"2 c9c8e3f1-2b4f-4a6c-8c2e-2a5f3c8e1f2b 54321 2023-10-01 16:15:00 \n",
"3 a5b3c6d2-1e4f-4c5e-9a1f-1f6a7b8e3c9f 98765 2023-10-01 17:45:00 \n",
"4 b8f4c3e2-2e4f-4c5e-8a2f-4c5e3b8e2f1a 13579 2023-10-01 18:30:00 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 KE en android 29 female bachelor upper_mid \n",
"1 UG sw web 34 male diploma lower_mid \n",
"2 TZ en ios 42 nonbinary postgraduate high \n",
"3 RW sw android 27 female secondary low \n",
"4 NG en web 55 male bachelor upper_mid \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 420.0 True 5 4 4 \n",
"1 600.0 True 3 3 2 \n",
"2 300.5 True 4 5 4 \n",
"3 720.0 False 2 3 3 \n",
"4 540.0 True 5 5 5 \n",
"\n",
" q_support nps is_detractor \n",
"0 5 9 False \n",
"1 4 5 False \n",
"2 5 10 False \n",
"3 2 3 True \n",
"4 5 8 False \n",
"\n",
"📈 Data types:\n",
"response_id object\n",
"respondent_id int64\n",
"submitted_at object\n",
"country object\n",
"language object\n",
"device object\n",
"age int64\n",
"gender object\n",
"education object\n",
"income_band object\n",
"completion_seconds float64\n",
"attention_passed bool\n",
"q_quality int64\n",
"q_value int64\n",
"q_ease int64\n",
"q_support int64\n",
"nps int64\n",
"is_detractor bool\n",
"dtype: object\n",
"\n",
"✅ SUCCESS! LLM generation is now working!\n",
"📊 Generated 5 survey responses using LLM\n"
]
}
],
"source": [
"# Test the fixed implementation\n",
"print(\"🧪 Testing the fixed LLM generation...\")\n",
"\n",
"# Test with small dataset\n",
"test_df = fixed_llm_generate_batch(CFG, 5)\n",
"print(f\"\\n📊 Generated dataset shape: {test_df.shape}\")\n",
"print(f\"\\n📋 First few rows:\")\n",
"print(test_df.head())\n",
"print(f\"\\n📈 Data types:\")\n",
"print(test_df.dtypes)\n",
"\n",
"if not test_df.empty:\n",
" print(f\"\\n✅ SUCCESS! LLM generation is now working!\")\n",
" print(f\"📊 Generated {len(test_df)} survey responses using LLM\")\n",
"else:\n",
" print(f\"\\n❌ Still having issues with LLM generation\")\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "dd83b842",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Testing larger dataset generation...\n",
"🚀 Generating 100 survey responses with adaptive batching\n",
"📊 Using optimal batch size: 10\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 100)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5238 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 90)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5235 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 80)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5232 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 70)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5239 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 60)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5238 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 50)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5236 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 40)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5229 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 30)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5244 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 20)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5234 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"\n",
"📦 Processing batch: 10 rows (remaining: 10)\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5238 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"✅ Generated total: 100 survey responses\n",
"\n",
"📊 Large dataset shape: (100, 18)\n",
"\n",
"📈 Summary statistics:\n",
" respondent_id age completion_seconds q_quality q_value \\\n",
"count 100.000000 100.000000 100.000000 100.000000 100.000000 \n",
"mean 33513.700000 34.070000 588.525000 3.740000 3.910000 \n",
"std 29233.800863 7.835757 230.530212 1.001211 0.995901 \n",
"min 10001.000000 22.000000 120.500000 2.000000 2.000000 \n",
"25% 10009.000000 28.000000 420.375000 3.000000 3.000000 \n",
"50% 15122.500000 33.000000 600.000000 4.000000 4.000000 \n",
"75% 55955.750000 39.250000 720.000000 5.000000 5.000000 \n",
"max 98765.000000 50.000000 1500.000000 5.000000 5.000000 \n",
"\n",
" q_ease q_support nps \n",
"count 100.000000 100.000000 100.000000 \n",
"mean 3.900000 3.910000 6.990000 \n",
"std 0.937437 0.985706 2.333312 \n",
"min 2.000000 2.000000 2.000000 \n",
"25% 3.000000 3.000000 5.000000 \n",
"50% 4.000000 4.000000 7.000000 \n",
"75% 5.000000 5.000000 9.000000 \n",
"max 5.000000 5.000000 10.000000 \n",
"💾 Saved: data\\survey_llm_fixed_20251023T005139Z.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\2716383900.py:12: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
]
}
],
"source": [
"# Test larger dataset generation with adaptive batching\n",
"print(\"🚀 Testing larger dataset generation...\")\n",
"large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n",
"if large_df.empty:\n",
"    print(\"❌ Large dataset generation failed\")\n",
"else:\n",
"    print(f\"\\n📊 Large dataset shape: {large_df.shape}\")\n",
"    print(f\"\\n📈 Summary statistics:\")\n",
"    print(large_df.describe())\n",
"\n",
"    # Save the results with a UTC timestamp in the filename.\n",
"    # datetime.now(timezone.utc) replaces the deprecated datetime.utcnow().\n",
"    from pathlib import Path\n",
"    from datetime import timezone\n",
"    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
"    ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
"    csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n",
"    large_df.to_csv(csv_path, index=False)\n",
"    print(f\"💾 Saved: {csv_path}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6029d3e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM available: True\n"
]
}
],
"source": [
"\n",
"def build_json_schema(CFG):\n",
" schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n",
" props = schema['items']['properties']; req = schema['items']['required']\n",
" for f in CFG['fields']:\n",
" name, t = f['name'], f['type']\n",
" req.append(name)\n",
" if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n",
" elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n",
" elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n",
" elif t == 'bool': props[name] = {'type':'boolean'}\n",
" else: props[name] = {'type':'string'}\n",
" return schema\n",
"\n",
"PROMPT_PREAMBLE = (\n",
" \"You are a data generator. Return ONLY JSON. \"\n",
" \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n",
" \"No prose, no code fences, no trailing commas.\"\n",
")\n",
"\n",
"def render_prompt(CFG, n_rows=100):\n",
"    \"\"\"Assemble the structured prompt payload sent to the LLM.\n",
"\n",
"    Returns a dict carrying the preamble, the target row count, a JSON\n",
"    Schema for the rows, a slimmed-down copy of the field constraints, and\n",
"    a closing instruction restating the required {'rows': [...]} shape.\n",
"    \"\"\"\n",
"    slim_fields = []\n",
"    for field in CFG['fields']:\n",
"        spec = {key: field[key] for key in ('name', 'type') if key in field}\n",
"        # Forward only the constraint keys the generator actually needs\n",
"        if 'min' in field and 'max' in field:\n",
"            spec['min'] = field['min']\n",
"            spec['max'] = field['max']\n",
"        for optional in ('values', 'fmt'):\n",
"            if optional in field:\n",
"                spec[optional] = field[optional]\n",
"        slim_fields.append(spec)\n",
"    return {\n",
"        'preamble': PROMPT_PREAMBLE,\n",
"        'n_rows': n_rows,\n",
"        'schema': build_json_schema(CFG),\n",
"        'constraints': {'fields': slim_fields},\n",
"        'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n",
"    }\n",
"\n",
"def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n",
" try:\n",
" obj = json.loads(raw)\n",
" if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n",
" return pd.DataFrame(obj['rows'])\n",
" except Exception:\n",
" pass\n",
" data = extract_strict_json(raw)\n",
" return pd.DataFrame(data)\n",
"\n",
"USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n",
"print('LLM available:', USE_LLM)\n",
"\n",
"def llm_generate_batch(CFG, n_rows=50):\n",
"    \"\"\"Generate one batch of n_rows synthetic survey rows.\n",
"\n",
"    When USE_LLM is true, sends the rendered prompt to the chat API and\n",
"    retries once with a stricter output-shape reminder if the first reply\n",
"    cannot be parsed. Any API or parsing failure falls back to the\n",
"    rule-based generator so a batch is always produced.\n",
"    \"\"\"\n",
"    if USE_LLM:\n",
"        try:\n",
"            from openai import OpenAI\n",
"            client = OpenAI()\n",
"\n",
"            def _request(user_content):\n",
"                # Single strict-JSON chat completion; returns the raw reply text.\n",
"                # Shared by the first attempt and the retry to avoid duplication.\n",
"                resp = client.chat.completions.create(\n",
"                    model='gpt-4o-mini',\n",
"                    response_format={'type': 'json_object'},\n",
"                    messages=[\n",
"                        {'role':'system','content':'You output strict JSON only.'},\n",
"                        {'role':'user','content': user_content}\n",
"                    ],\n",
"                    temperature=0.2,\n",
"                    max_tokens=8192,\n",
"                )\n",
"                return resp.choices[0].message.content\n",
"\n",
"            prompt = json.dumps(render_prompt(CFG, n_rows))\n",
"            raw = _request(prompt)\n",
"            try:\n",
"                return parse_llm_json_to_df(raw)\n",
"            except Exception:\n",
"                # Retry once, restating the required output shape explicitly\n",
"                stricter = (\n",
"                    prompt\n",
"                    + \"\\nReturn ONLY a JSON object structured as: \"\n",
"                    + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n",
"                    + \"No prose, no explanations.\"\n",
"                )\n",
"                return parse_llm_json_to_df(_request(stricter))\n",
"        except Exception as e:\n",
"            print('LLM error, fallback to rule-based mock:', e)\n",
"    # Offline path / last resort: deterministic rule-based rows\n",
"    tmp = dict(CFG); tmp['rows'] = n_rows\n",
"    return generate_rule_based(tmp)\n",
"\n",
"def generate_llm(CFG, total_rows=200, batch_size=50):\n",
"    \"\"\"Generate total_rows rows by calling llm_generate_batch in batches.\n",
"\n",
"    Fix: decrement the remaining count by the rows actually returned, not\n",
"    the requested batch size, so an under-delivering batch triggers a\n",
"    top-up request instead of silently shortening the dataset. An empty\n",
"    batch still counts as the requested size to guarantee termination,\n",
"    and the concatenated result is trimmed to exactly total_rows.\n",
"    \"\"\"\n",
"    dfs = []\n",
"    remaining = total_rows\n",
"    while remaining > 0:\n",
"        b = min(batch_size, remaining)\n",
"        batch = llm_generate_batch(CFG, n_rows=b)\n",
"        dfs.append(batch)\n",
"        # Empty batch: charge the requested size so the loop cannot spin forever\n",
"        remaining -= len(batch) if len(batch) > 0 else b\n",
"        time.sleep(0.2)  # brief pause between requests to respect rate limits\n",
"    if not dfs:\n",
"        return pd.DataFrame()\n",
"    return pd.concat(dfs, ignore_index=True).iloc[:total_rows]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2e759087",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM error, fallback to rule-based mock: No JSON array found in model output.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>response_id</th>\n",
" <th>respondent_id</th>\n",
" <th>submitted_at</th>\n",
" <th>country</th>\n",
" <th>language</th>\n",
" <th>device</th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>education</th>\n",
" <th>income_band</th>\n",
" <th>completion_seconds</th>\n",
" <th>attention_passed</th>\n",
" <th>q_quality</th>\n",
" <th>q_value</th>\n",
" <th>q_ease</th>\n",
" <th>q_support</th>\n",
" <th>nps</th>\n",
" <th>is_detractor</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9e7811bd-27ee-4b7c-9b7a-c98441e337f0</td>\n",
" <td>40160</td>\n",
" <td>2024-08-18 19:10:06</td>\n",
" <td>KE</td>\n",
" <td>sw</td>\n",
" <td>web</td>\n",
" <td>28</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>lower_mid</td>\n",
" <td>1800.000000</td>\n",
" <td>True</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>85ec8b90-5468-4880-8309-e325da14d877</td>\n",
" <td>55381</td>\n",
" <td>2025-01-24 12:21:13</td>\n",
" <td>TZ</td>\n",
" <td>sw</td>\n",
" <td>ios</td>\n",
" <td>23</td>\n",
" <td>female</td>\n",
" <td>bachelor</td>\n",
" <td>high</td>\n",
" <td>431.412783</td>\n",
" <td>True</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>498dff10-040f-4206-8170-dfce0d5a69f0</td>\n",
" <td>48338</td>\n",
" <td>2025-07-15 22:21:54</td>\n",
" <td>TZ</td>\n",
" <td>en</td>\n",
" <td>ios</td>\n",
" <td>49</td>\n",
" <td>male</td>\n",
" <td>bachelor</td>\n",
" <td>low</td>\n",
" <td>1800.000000</td>\n",
" <td>True</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ddf11d94-5d6e-4322-9811-4e763f5ed46b</td>\n",
" <td>59925</td>\n",
" <td>2025-01-27 00:16:57</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>web</td>\n",
" <td>22</td>\n",
" <td>male</td>\n",
" <td>bachelor</td>\n",
" <td>upper_mid</td>\n",
" <td>656.050991</td>\n",
" <td>True</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2ef22a0c-fd13-4798-9276-f43831b8f7bc</td>\n",
" <td>68993</td>\n",
" <td>2024-08-19 04:21:49</td>\n",
" <td>KE</td>\n",
" <td>en</td>\n",
" <td>android</td>\n",
" <td>40</td>\n",
" <td>male</td>\n",
" <td>secondary</td>\n",
" <td>lower_mid</td>\n",
" <td>1553.938944</td>\n",
" <td>True</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" response_id respondent_id submitted_at \\\n",
"0 9e7811bd-27ee-4b7c-9b7a-c98441e337f0 40160 2024-08-18 19:10:06 \n",
"1 85ec8b90-5468-4880-8309-e325da14d877 55381 2025-01-24 12:21:13 \n",
"2 498dff10-040f-4206-8170-dfce0d5a69f0 48338 2025-07-15 22:21:54 \n",
"3 ddf11d94-5d6e-4322-9811-4e763f5ed46b 59925 2025-01-27 00:16:57 \n",
"4 2ef22a0c-fd13-4798-9276-f43831b8f7bc 68993 2024-08-19 04:21:49 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 KE sw web 28 male secondary lower_mid \n",
"1 TZ sw ios 23 female bachelor high \n",
"2 TZ en ios 49 male bachelor low \n",
"3 KE en web 22 male bachelor upper_mid \n",
"4 KE en android 40 male secondary lower_mid \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 1800.000000 True 4 3 3 \n",
"1 431.412783 True 3 2 3 \n",
"2 1800.000000 True 2 3 3 \n",
"3 656.050991 True 4 4 1 \n",
"4 1553.938944 True 2 2 5 \n",
"\n",
" q_support nps is_detractor \n",
"0 3 4 True \n",
"1 4 4 False \n",
"2 1 3 False \n",
"3 3 5 False \n",
"4 1 5 False "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n",
"df_llm.head()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "6d4908ad",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧪 Testing improved LLM generation with adaptive batching...\n",
"\n",
"📦 Testing small batch (10 rows)...\n",
"🔄 Generating 10 survey responses with LLM...\n",
"📊 Using max_tokens: 3500 (estimated: 3500)\n",
"📝 Raw response length: 5233 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 10 rows\n",
"✅ Successfully generated 10 survey responses\n",
"✅ Small batch result: 10 rows\n",
"\n",
"📦 Testing medium dataset (30 rows) with adaptive batching...\n",
"🚀 Generating 30 survey responses with adaptive batching\n",
"📊 Using optimal batch size: 15\n",
"\n",
"📦 Processing batch: 15 rows (remaining: 30)\n",
"🔄 Generating 15 survey responses with LLM...\n",
"📊 Using max_tokens: 5000 (estimated: 5000)\n",
"📝 Raw response length: 7839 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 15 rows\n",
"✅ Successfully generated 15 survey responses\n",
"\n",
"📦 Processing batch: 15 rows (remaining: 15)\n",
"🔄 Generating 15 survey responses with LLM...\n",
"📊 Using max_tokens: 5000 (estimated: 5000)\n",
"📝 Raw response length: 7841 characters\n",
"🔍 Parsed JSON type: <class 'dict'>\n",
"📊 Found data in 'responses': 15 rows\n",
"✅ Successfully generated 15 survey responses\n",
"✅ Generated total: 30 survey responses\n",
"✅ Medium dataset result: 30 rows\n",
"\n",
"📊 Dataset shape: (30, 18)\n",
"\n",
"📋 First few rows:\n",
" response_id respondent_id submitted_at \\\n",
"0 d1e5c4a3-4b1f-4f6b-8f9e-9f1e1f2e3d4c 10001 2023-10-01 14:30:00 \n",
"1 c2b1d4a6-7f8e-4c5c-9d8f-1e2c3b4a5e6f 10002 2023-10-01 15:00:00 \n",
"2 e3f2c5b7-8a2d-4c8e-9f1b-2c3d4e5f6a7b 10003 2023-10-01 15:30:00 \n",
"3 f4a5b6c8-9d3e-4b1f-9f2c-3d4e5f6a7b8c 10004 2023-10-01 16:00:00 \n",
"4 g5b6c7d9-0e4f-4b2a-8f3d-4e5f6a7b8c9d 10005 2023-10-01 16:30:00 \n",
"\n",
" country language device age gender education income_band \\\n",
"0 KE en android 28 female bachelor upper_mid \n",
"1 UG sw web 35 male diploma lower_mid \n",
"2 TZ en ios 42 nonbinary postgraduate high \n",
"3 RW sw web 29 female secondary upper_mid \n",
"4 NG en android 50 male bachelor high \n",
"\n",
" completion_seconds attention_passed q_quality q_value q_ease \\\n",
"0 450.0 True 5 4 5 \n",
"1 600.0 True 3 2 4 \n",
"2 720.0 True 4 5 4 \n",
"3 300.0 True 3 3 3 \n",
"4 540.0 True 5 5 5 \n",
"\n",
" q_support nps is_detractor \n",
"0 4 9 False \n",
"1 3 5 False \n",
"2 5 10 False \n",
"3 4 6 False \n",
"4 5 10 False \n",
"💾 Saved: data\\survey_adaptive_batch_20251023T005927Z.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1770033334.py:22: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
" ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
]
}
],
"source": [
"# Test the improved LLM generation with adaptive batching\n",
"print(\"🧪 Testing improved LLM generation with adaptive batching...\")\n",
"\n",
"# Test with smaller dataset first\n",
"print(\"\\n📦 Testing small batch (10 rows)...\")\n",
"small_df = fixed_llm_generate_batch(CFG, 10)\n",
"print(f\"✅ Small batch result: {len(small_df)} rows\")\n",
"\n",
"# Test with medium dataset using adaptive batching\n",
"print(\"\\n📦 Testing medium dataset (30 rows) with adaptive batching...\")\n",
"medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n",
"print(f\"✅ Medium dataset result: {len(medium_df)} rows\")\n",
"\n",
"if not medium_df.empty:\n",
"    print(f\"\\n📊 Dataset shape: {medium_df.shape}\")\n",
"    print(f\"\\n📋 First few rows:\")\n",
"    print(medium_df.head())\n",
"\n",
"    # Save the results with a UTC timestamp in the filename.\n",
"    # datetime.now(timezone.utc) replaces the deprecated datetime.utcnow().\n",
"    from pathlib import Path\n",
"    from datetime import timezone\n",
"    out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
"    ts = datetime.now(timezone.utc).strftime(\"%Y%m%dT%H%M%SZ\")\n",
"    csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n",
"    medium_df.to_csv(csv_path, index=False)\n",
"    print(f\"💾 Saved: {csv_path}\")\n",
"else:\n",
"    print(\"❌ Medium dataset generation failed\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}