diff --git a/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb
new file mode 100644
index 0000000..a11fa96
--- /dev/null
+++ b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb
@@ -0,0 +1,1897 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a8dbb4e8",
+ "metadata": {},
+ "source": [
+ "# ๐งช Survey Synthetic Dataset Generator โ Week 3 Task"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "8d86f629",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "โ
Base libraries ready. Pandera available: True\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "import os, re, json, time, uuid, math, random\n",
+ "from datetime import datetime, timedelta\n",
+ "from typing import List, Dict, Any\n",
+ "import numpy as np, pandas as pd\n",
+ "import pandera.pandas as pa\n",
+ "random.seed(7); np.random.seed(7)\n",
+ "print(\"โ
Base libraries ready. Pandera available:\", pa is not None)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "f196ae73",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def extract_strict_json(text: str):\n",
+ " \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n",
+ " if text is None:\n",
+ " raise ValueError(\"Empty model output.\")\n",
+ " \n",
+ " t = text.strip()\n",
+ " \n",
+ " # Strategy 1: Direct JSON parsing\n",
+ " try:\n",
+ " obj = json.loads(t)\n",
+ " if isinstance(obj, list):\n",
+ " return obj\n",
+ " elif isinstance(obj, dict):\n",
+ " for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n",
+ " if key in obj and isinstance(obj[key], list):\n",
+ " return obj[key]\n",
+ " if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n",
+ " return [obj[k] for k in sorted(obj.keys(), key=int)]\n",
+ " except json.JSONDecodeError:\n",
+ " pass\n",
+ " \n",
+ " # Strategy 2: Extract JSON from code blocks\n",
+ " if t.startswith(\"```\"):\n",
+ " t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n",
+ " \n",
+ " # Strategy 3: Find JSON array in text\n",
+ " start, end = t.find('['), t.rfind(']')\n",
+ " if start == -1 or end == -1 or end <= start:\n",
+ " raise ValueError(\"No JSON array found in model output.\")\n",
+ " \n",
+ " t = t[start:end+1]\n",
+ " \n",
+ " # Strategy 4: Fix common JSON issues\n",
+ " t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t) # Remove trailing commas\n",
+ " t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t) # Replace NaN/Infinity\n",
+ " t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\") # Remove invisible characters\n",
+ " \n",
+ " try:\n",
+ " return json.loads(t)\n",
+ " except json.JSONDecodeError as e:\n",
+ " raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3670fa0d",
+ "metadata": {},
+ "source": [
+ "## 1) Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "d16bd03a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded config for 800 rows and 18 fields.\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "CFG = {\n",
+ " \"rows\": 800,\n",
+ " \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n",
+ " \"fields\": [\n",
+ " {\"name\": \"response_id\", \"type\": \"uuid4\"},\n",
+ " {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n",
+ " {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n",
+ " {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n",
+ " {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n",
+ " {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n",
+ " {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n",
+ " {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n",
+ " {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n",
+ " {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n",
+ " {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n",
+ " {\"name\": \"attention_passed\", \"type\": \"bool\"},\n",
+ " {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
+ " {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
+ " {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
+ " {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n",
+ " {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n",
+ " {\"name\": \"is_detractor\", \"type\": \"bool\"}\n",
+ " ]\n",
+ "}\n",
+ "print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7da1f429",
+ "metadata": {},
+ "source": [
+ "## 2) Helpers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "d2f5fdff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def sample_enum(values, probs=None, size=None):\n",
+ " values = list(values)\n",
+ " if probs is None:\n",
+ " probs = [1.0 / len(values)] * len(values)\n",
+ " return np.random.choice(values, p=probs, size=size)\n",
+ "\n",
+ "def sample_numeric(field_cfg, size=1):\n",
+ " t = field_cfg[\"type\"]\n",
+ " if t == \"int\":\n",
+ " lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n",
+ " dist = field_cfg.get(\"distribution\", \"uniform\")\n",
+ " if dist == \"uniform\":\n",
+ " return np.random.randint(lo, hi + 1, size=size)\n",
+ " elif dist == \"normal\":\n",
+ " mu = (lo + hi) / 2.0\n",
+ " sigma = (hi - lo) / 6.0\n",
+ " out = np.random.normal(mu, sigma, size=size)\n",
+ " return np.clip(out, lo, hi).astype(int)\n",
+ " else:\n",
+ " return np.random.randint(lo, hi + 1, size=size)\n",
+ " elif t == \"float\":\n",
+ " lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n",
+ " dist = field_cfg.get(\"distribution\", \"uniform\")\n",
+ " if dist == \"uniform\":\n",
+ " return np.random.uniform(lo, hi, size=size)\n",
+ " elif dist == \"normal\":\n",
+ " mu = (lo + hi) / 2.0\n",
+ " sigma = (hi - lo) / 6.0\n",
+ " return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n",
+ " elif dist == \"lognormal\":\n",
+ " mu = math.log(max(1e-3, (lo + hi) / 2.0))\n",
+ " sigma = 0.75\n",
+ " out = np.random.lognormal(mu, sigma, size=size)\n",
+ " return np.clip(out, lo, hi)\n",
+ " else:\n",
+ " return np.random.uniform(lo, hi, size=size)\n",
+ " else:\n",
+ " raise ValueError(\"Unsupported numeric type\")\n",
+ "\n",
+ "def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n",
+ " s = datetime.fromisoformat(start)\n",
+ " e = datetime.fromisoformat(end)\n",
+ " total = int((e - s).total_seconds())\n",
+ " r = np.random.randint(0, total, size=size)\n",
+ " return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5f24111a",
+ "metadata": {},
+ "source": [
+ "## 3) Rule-based Generator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "cd61330d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " response_id | \n",
+ " respondent_id | \n",
+ " submitted_at | \n",
+ " country | \n",
+ " language | \n",
+ " device | \n",
+ " age | \n",
+ " gender | \n",
+ " education | \n",
+ " income_band | \n",
+ " completion_seconds | \n",
+ " attention_passed | \n",
+ " q_quality | \n",
+ " q_value | \n",
+ " q_ease | \n",
+ " q_support | \n",
+ " nps | \n",
+ " is_detractor | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " f099c1b6-a4ae-4fb0-ba98-89a81008c424 | \n",
+ " 71615 | \n",
+ " 2024-04-13 19:02:44 | \n",
+ " ZA | \n",
+ " en | \n",
+ " web | \n",
+ " 47 | \n",
+ " male | \n",
+ " secondary | \n",
+ " low | \n",
+ " 897.995012 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b | \n",
+ " 68564 | \n",
+ " 2024-03-05 23:30:30 | \n",
+ " KE | \n",
+ " en | \n",
+ " android | \n",
+ " 67 | \n",
+ " female | \n",
+ " bachelor | \n",
+ " lower_mid | \n",
+ " 935.607966 | \n",
+ " True | \n",
+ " 1 | \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " a9345f69-be75-46b9-8cd3-a276ce0a66bd | \n",
+ " 59689 | \n",
+ " 2024-11-10 03:38:07 | \n",
+ " RW | \n",
+ " sw | \n",
+ " android | \n",
+ " 23 | \n",
+ " male | \n",
+ " bachelor | \n",
+ " low | \n",
+ " 1431.517701 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " b4fa8625-d153-4465-ad73-1c4a48eed2f1 | \n",
+ " 20742 | \n",
+ " 2024-11-19 17:40:58 | \n",
+ " KE | \n",
+ " en | \n",
+ " ios | \n",
+ " 68 | \n",
+ " female | \n",
+ " secondary | \n",
+ " upper_mid | \n",
+ " 448.519416 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 10 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " e0ad4bbc-b576-4913-8786-302f06b5e9f7 | \n",
+ " 63459 | \n",
+ " 2024-07-28 04:23:37 | \n",
+ " KE | \n",
+ " en | \n",
+ " ios | \n",
+ " 34 | \n",
+ " male | \n",
+ " secondary | \n",
+ " low | \n",
+ " 1179.970734 | \n",
+ " True | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " response_id respondent_id submitted_at \\\n",
+ "0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
+ "1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
+ "2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
+ "3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
+ "4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 ZA en web 47 male secondary low \n",
+ "1 KE en android 67 female bachelor lower_mid \n",
+ "2 RW sw android 23 male bachelor low \n",
+ "3 KE en ios 68 female secondary upper_mid \n",
+ "4 KE en ios 34 male secondary low \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 897.995012 True 5 3 1 \n",
+ "1 935.607966 True 1 5 2 \n",
+ "2 1431.517701 True 5 2 5 \n",
+ "3 448.519416 True 5 5 5 \n",
+ "4 1179.970734 True 3 1 3 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 3 4 True \n",
+ "1 3 5 False \n",
+ "2 5 7 False \n",
+ "3 3 10 False \n",
+ "4 3 5 False "
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n",
+ " n = CFG[\"rows\"]\n",
+ " dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n",
+ " data = {}\n",
+ " for f in CFG[\"fields\"]:\n",
+ " name, t = f[\"name\"], f[\"type\"]\n",
+ " if t == \"uuid4\":\n",
+ " data[name] = [str(uuid.uuid4()) for _ in range(n)]\n",
+ " elif t in (\"int\",\"float\"):\n",
+ " data[name] = sample_numeric(f, size=n)\n",
+ " elif t == \"enum\":\n",
+ " data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n",
+ " elif t == \"datetime\":\n",
+ " data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n",
+ " elif t == \"bool\":\n",
+ " data[name] = np.random.rand(n) < 0.9 # 90% True\n",
+ " else:\n",
+ " data[name] = [None]*n\n",
+ " df = pd.DataFrame(data)\n",
+ "\n",
+ " # Derive NPS roughly from likert questions\n",
+ " if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n",
+ " likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n",
+ " df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)\n",
+ "\n",
+ " # Heuristic target: is_detractor more likely when completion high & attention failed\n",
+ " if \"is_detractor\" in df.columns:\n",
+ " base = 0.25\n",
+ " comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n",
+ " attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n",
+ " boost = (comp > 900).astype(int) + (~attn).astype(int)\n",
+ " p = np.clip(base + 0.15*boost, 0.01, 0.95)\n",
+ " df[\"is_detractor\"] = np.random.rand(n) < p\n",
+ "\n",
+ " return df\n",
+ "\n",
+ "df_rule = generate_rule_based(CFG)\n",
+ "df_rule.head()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dd9eff20",
+ "metadata": {},
+ "source": [
+ "## 4) Validation (Pandera optional)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "9a4ef86a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation error: {\n",
+ " \"SCHEMA\": {\n",
+ " \"WRONG_DATATYPE\": [\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"respondent_id\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'respondent_id' to have type int64, got int32\"\n",
+ " },\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"age\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'age' to have type int64, got int32\"\n",
+ " },\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"q_quality\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'q_quality' to have type int64, got int32\"\n",
+ " },\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"q_value\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'q_value' to have type int64, got int32\"\n",
+ " },\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"q_ease\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'q_ease' to have type int64, got int32\"\n",
+ " },\n",
+ " {\n",
+ " \"schema\": null,\n",
+ " \"column\": \"q_support\",\n",
+ " \"check\": \"dtype('int64')\",\n",
+ " \"error\": \"expected series 'q_support' to have type int64, got int32\"\n",
+ " }\n",
+ " ]\n",
+ " }\n",
+ "}\n",
+ "{'engine': 'pandera', 'valid_rows': 800, 'invalid_rows': 0, 'notes': 'Non-strict mode.'}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " response_id | \n",
+ " respondent_id | \n",
+ " submitted_at | \n",
+ " country | \n",
+ " language | \n",
+ " device | \n",
+ " age | \n",
+ " gender | \n",
+ " education | \n",
+ " income_band | \n",
+ " completion_seconds | \n",
+ " attention_passed | \n",
+ " q_quality | \n",
+ " q_value | \n",
+ " q_ease | \n",
+ " q_support | \n",
+ " nps | \n",
+ " is_detractor | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " f099c1b6-a4ae-4fb0-ba98-89a81008c424 | \n",
+ " 71615 | \n",
+ " 2024-04-13 19:02:44 | \n",
+ " ZA | \n",
+ " en | \n",
+ " web | \n",
+ " 47 | \n",
+ " male | \n",
+ " secondary | \n",
+ " low | \n",
+ " 897.995012 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b | \n",
+ " 68564 | \n",
+ " 2024-03-05 23:30:30 | \n",
+ " KE | \n",
+ " en | \n",
+ " android | \n",
+ " 67 | \n",
+ " female | \n",
+ " bachelor | \n",
+ " lower_mid | \n",
+ " 935.607966 | \n",
+ " True | \n",
+ " 1 | \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " a9345f69-be75-46b9-8cd3-a276ce0a66bd | \n",
+ " 59689 | \n",
+ " 2024-11-10 03:38:07 | \n",
+ " RW | \n",
+ " sw | \n",
+ " android | \n",
+ " 23 | \n",
+ " male | \n",
+ " bachelor | \n",
+ " low | \n",
+ " 1431.517701 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " b4fa8625-d153-4465-ad73-1c4a48eed2f1 | \n",
+ " 20742 | \n",
+ " 2024-11-19 17:40:58 | \n",
+ " KE | \n",
+ " en | \n",
+ " ios | \n",
+ " 68 | \n",
+ " female | \n",
+ " secondary | \n",
+ " upper_mid | \n",
+ " 448.519416 | \n",
+ " True | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 3 | \n",
+ " 10 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " e0ad4bbc-b576-4913-8786-302f06b5e9f7 | \n",
+ " 63459 | \n",
+ " 2024-07-28 04:23:37 | \n",
+ " KE | \n",
+ " en | \n",
+ " ios | \n",
+ " 34 | \n",
+ " male | \n",
+ " secondary | \n",
+ " low | \n",
+ " 1179.970734 | \n",
+ " True | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " response_id respondent_id submitted_at \\\n",
+ "0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n",
+ "1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n",
+ "2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n",
+ "3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n",
+ "4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 ZA en web 47 male secondary low \n",
+ "1 KE en android 67 female bachelor lower_mid \n",
+ "2 RW sw android 23 male bachelor low \n",
+ "3 KE en ios 68 female secondary upper_mid \n",
+ "4 KE en ios 34 male secondary low \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 897.995012 True 5 3 1 \n",
+ "1 935.607966 True 1 5 2 \n",
+ "2 1431.517701 True 5 2 5 \n",
+ "3 448.519416 True 5 5 5 \n",
+ "4 1179.970734 True 3 1 3 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 3 4 True \n",
+ "1 3 5 False \n",
+ "2 5 7 False \n",
+ "3 3 10 False \n",
+ "4 3 5 False "
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "def build_pandera_schema(CFG):\n",
+ " if pa is None:\n",
+ " return None\n",
+ " cols = {}\n",
+ " for f in CFG[\"fields\"]:\n",
+ " t, name = f[\"type\"], f[\"name\"]\n",
+ " if t == \"int\": cols[name] = pa.Column(int)\n",
+ " elif t == \"float\": cols[name] = pa.Column(float)\n",
+ " elif t == \"enum\": cols[name] = pa.Column(object)\n",
+ " elif t == \"datetime\": cols[name] = pa.Column(object)\n",
+ " elif t == \"uuid4\": cols[name] = pa.Column(object)\n",
+ " elif t == \"bool\": cols[name] = pa.Column(bool)\n",
+ " else: cols[name] = pa.Column(object)\n",
+ " return pa.DataFrameSchema(cols) if pa is not None else None\n",
+ "\n",
+ "def validate_df(df, CFG):\n",
+ " schema = build_pandera_schema(CFG)\n",
+ " if schema is None:\n",
+ " return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n",
+ " try:\n",
+ " v = schema.validate(df, lazy=True)\n",
+ " return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n",
+ " except Exception as e:\n",
+ " print(\"Validation error:\", e)\n",
+ " return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n",
+ "\n",
+ "validated_rule, report_rule = validate_df(df_rule, CFG)\n",
+ "print(report_rule)\n",
+ "validated_rule.head()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d5f1d93a",
+ "metadata": {},
+ "source": [
+ "## 5) Save"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "73626b4c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved: data/survey_rule_20251023T004106Z.csv\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1233117399.py:3: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
+ " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "from pathlib import Path\n",
+ "out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
+ "ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
+ "csv_path = out / f\"survey_rule_{ts}.csv\"\n",
+ "validated_rule.to_csv(csv_path, index=False)\n",
+ "print(\"Saved:\", csv_path.as_posix())\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "87c89b51",
+ "metadata": {},
+ "source": [
+ "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "24e94771",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Fixed LLM Generation Functions\n",
+ "def create_survey_prompt(CFG, n_rows=50):\n",
+ " \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n",
+ " fields_desc = []\n",
+ " for field in CFG['fields']:\n",
+ " name = field['name']\n",
+ " field_type = field['type']\n",
+ " \n",
+ " if field_type == 'int':\n",
+ " min_val = field.get('min', 0)\n",
+ " max_val = field.get('max', 100)\n",
+ " fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n",
+ " elif field_type == 'float':\n",
+ " min_val = field.get('min', 0.0)\n",
+ " max_val = field.get('max', 100.0)\n",
+ " fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n",
+ " elif field_type == 'enum':\n",
+ " values = field.get('values', [])\n",
+ " fields_desc.append(f\" - {name}: one of {values}\")\n",
+ " elif field_type == 'bool':\n",
+ " fields_desc.append(f\" - {name}: boolean (true/false)\")\n",
+ " elif field_type == 'uuid4':\n",
+ " fields_desc.append(f\" - {name}: UUID string\")\n",
+ " elif field_type == 'datetime':\n",
+ " fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n",
+ " fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n",
+ " else:\n",
+ " fields_desc.append(f\" - {name}: {field_type}\")\n",
+ " \n",
+ " prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n",
+ "\n",
+ "Schema:\n",
+ "{chr(10).join(fields_desc)}\n",
+ "\n",
+ "CRITICAL REQUIREMENTS:\n",
+ "- Return a JSON object with a \"responses\" key containing an array\n",
+ "- Each object in the array must have all required fields\n",
+ "- Use realistic, diverse values for survey responses\n",
+ "- No trailing commas\n",
+ "- No comments or explanations\n",
+ "\n",
+ "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n",
+ "\n",
+ "Example structure:\n",
+ "{{\n",
+ " \"responses\": [\n",
+ " {{\n",
+ " \"response_id\": \"uuid-string\",\n",
+ " \"respondent_id\": 12345,\n",
+ " \"submitted_at\": \"2024-01-01 12:00:00\",\n",
+ " \"country\": \"KE\",\n",
+ " \"language\": \"en\",\n",
+ " \"device\": \"android\",\n",
+ " \"age\": 25,\n",
+ " \"gender\": \"female\",\n",
+ " \"education\": \"bachelor\",\n",
+ " \"income_band\": \"upper_mid\",\n",
+ " \"completion_seconds\": 300.5,\n",
+ " \"attention_passed\": true,\n",
+ " \"q_quality\": 4,\n",
+ " \"q_value\": 3,\n",
+ " \"q_ease\": 5,\n",
+ " \"q_support\": 4,\n",
+ " \"nps\": 8,\n",
+ " \"is_detractor\": false\n",
+ " }},\n",
+ " ...\n",
+ " ]\n",
+ "}}\n",
+ "\n",
+ "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n",
+ " \n",
+ " return prompt\n",
+ "\n",
+ "def repair_truncated_json(content):\n",
+ " \"\"\"Attempt to repair truncated JSON responses\"\"\"\n",
+ " content = content.strip()\n",
+ " \n",
+ " # If it starts with { but doesn't end with }, try to close it\n",
+ " if content.startswith('{') and not content.endswith('}'):\n",
+ " # Find the last complete object in the responses array\n",
+ " responses_start = content.find('\"responses\": [')\n",
+ " if responses_start != -1:\n",
+ " # Find the last complete object\n",
+ " brace_count = 0\n",
+ " last_complete_pos = -1\n",
+ " in_string = False\n",
+ " escape_next = False\n",
+ " \n",
+ " for i, char in enumerate(content[responses_start:], responses_start):\n",
+ " if escape_next:\n",
+ " escape_next = False\n",
+ " continue\n",
+ " \n",
+ " if char == '\\\\':\n",
+ " escape_next = True\n",
+ " continue\n",
+ " \n",
+ " if char == '\"' and not escape_next:\n",
+ " in_string = not in_string\n",
+ " continue\n",
+ " \n",
+ " if not in_string:\n",
+ " if char == '{':\n",
+ " brace_count += 1\n",
+ " elif char == '}':\n",
+ " brace_count -= 1\n",
+ " if brace_count == 0:\n",
+ " last_complete_pos = i\n",
+ " break\n",
+ " \n",
+ " if last_complete_pos != -1:\n",
+ " # Truncate at the last complete object and close the JSON\n",
+ " repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n",
+ " print(f\"๐ง Repaired JSON: truncated at position {last_complete_pos}\")\n",
+ " return repaired\n",
+ " \n",
+ " return content\n",
+ "\n",
+ "def fixed_llm_generate_batch(CFG, n_rows=50):\n",
+ " \"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n",
+ " if not os.getenv('OPENAI_API_KEY'):\n",
+ " print(\"No OpenAI API key, using rule-based fallback\")\n",
+ " tmp = dict(CFG); tmp['rows'] = n_rows\n",
+ " return generate_rule_based(tmp)\n",
+ " \n",
+ " try:\n",
+ " from openai import OpenAI\n",
+ " client = OpenAI()\n",
+ " \n",
+ " prompt = create_survey_prompt(CFG, n_rows)\n",
+ " \n",
+ " print(f\"๐ Generating {n_rows} survey responses with LLM...\")\n",
+ " \n",
+ " # Calculate appropriate max_tokens based on batch size\n",
+ " # Roughly 200-300 tokens per row, with some buffer\n",
+ " estimated_tokens = n_rows * 300 + 500 # Buffer for JSON structure\n",
+ " max_tokens = min(max(estimated_tokens, 2000), 8000) # Between 2k-8k tokens\n",
+ " \n",
+ " print(f\"๐ Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n",
+ " \n",
+ " response = client.chat.completions.create(\n",
+ " model='gpt-4o-mini',\n",
+ " messages=[\n",
+ " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n",
+ " {'role': 'user', 'content': prompt}\n",
+ " ],\n",
+ " temperature=0.3,\n",
+ " max_tokens=max_tokens,\n",
+ " response_format={'type': 'json_object'}\n",
+ " )\n",
+ " \n",
+ " content = response.choices[0].message.content\n",
+ " print(f\"๐ Raw response length: {len(content)} characters\")\n",
+ " \n",
+ " # Check if response appears truncated\n",
+ " if not content.strip().endswith('}') and not content.strip().endswith(']'):\n",
+ " print(\"โ ๏ธ Response appears truncated, attempting repair...\")\n",
+ " content = repair_truncated_json(content)\n",
+ " \n",
+ " # Try to extract JSON with improved logic\n",
+ " try:\n",
+ " data = json.loads(content)\n",
+ " print(f\"๐ Parsed JSON type: {type(data)}\")\n",
+ " \n",
+ " if isinstance(data, list):\n",
+ " df = pd.DataFrame(data)\n",
+ " print(f\"๐ Direct array: {len(df)} rows\")\n",
+ " elif isinstance(data, dict):\n",
+ " # Check for common keys that might contain the data\n",
+ " for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n",
+ " if key in data and isinstance(data[key], list):\n",
+ " df = pd.DataFrame(data[key])\n",
+ " print(f\"๐ Found data in '{key}': {len(df)} rows\")\n",
+ " break\n",
+ " else:\n",
+ " # If no standard key found, check if all values are lists/objects\n",
+ " list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n",
+ " if list_keys:\n",
+ " # Use the first list key found\n",
+ " key = list_keys[0]\n",
+ " df = pd.DataFrame(data[key])\n",
+ " print(f\"๐ Found data in '{key}': {len(df)} rows\")\n",
+ " else:\n",
+ " # Try to convert the dict values to a list\n",
+ " if all(isinstance(v, dict) for v in data.values()):\n",
+ " df = pd.DataFrame(list(data.values()))\n",
+ " print(f\"๐ Converted dict values: {len(df)} rows\")\n",
+ " else:\n",
+ " raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n",
+ " else:\n",
+ " raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n",
+ " \n",
+ " if len(df) == n_rows:\n",
+ " print(f\"โ
Successfully generated {len(df)} survey responses\")\n",
+ " return df\n",
+ " else:\n",
+ " print(f\"โ ๏ธ Generated {len(df)} rows, expected {n_rows}\")\n",
+ " if len(df) > 0:\n",
+ " return df\n",
+ " else:\n",
+ " raise ValueError(\"No data generated\")\n",
+ " \n",
+ " except json.JSONDecodeError as e:\n",
+ " print(f\"โ JSON parsing failed: {str(e)}\")\n",
+ " # Try the improved extract_strict_json function\n",
+ " try:\n",
+ " data = extract_strict_json(content)\n",
+ " df = pd.DataFrame(data)\n",
+ " print(f\"โ
Recovered with strict parsing: {len(df)} rows\")\n",
+ " return df\n",
+ " except Exception as e2:\n",
+ " print(f\"โ Strict parsing also failed: {str(e2)}\")\n",
+ " # Print a sample of the content for debugging\n",
+ " print(f\"๐ Content sample: {content[:500]}...\")\n",
+ " raise e2\n",
+ " \n",
+ " except Exception as e:\n",
+ " print(f'โ LLM error, fallback to rule-based mock: {str(e)}')\n",
+ " tmp = dict(CFG); tmp['rows'] = n_rows\n",
+ " return generate_rule_based(tmp)\n",
+ "\n",
+ "def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n",
+ " \"\"\"Fixed LLM generation with adaptive batch processing\"\"\"\n",
+ " print(f\"๐ Generating {total_rows} survey responses with adaptive batching\")\n",
+ " \n",
+ " # Adaptive batch sizing based on total rows\n",
+ " if total_rows <= 20:\n",
+ " optimal_batch_size = min(batch_size, total_rows)\n",
+ " elif total_rows <= 50:\n",
+ " optimal_batch_size = min(15, batch_size)\n",
+ " elif total_rows <= 100:\n",
+ " optimal_batch_size = min(10, batch_size)\n",
+ " else:\n",
+ " optimal_batch_size = min(8, batch_size)\n",
+ " \n",
+ " print(f\"๐ Using optimal batch size: {optimal_batch_size}\")\n",
+ " \n",
+ " all_dataframes = []\n",
+ " remaining = total_rows\n",
+ " \n",
+ " while remaining > 0:\n",
+ " current_batch_size = min(optimal_batch_size, remaining)\n",
+ " print(f\"\\n๐ฆ Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n",
+ " \n",
+ " try:\n",
+ " batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n",
+ " all_dataframes.append(batch_df)\n",
+ " remaining -= len(batch_df)\n",
+ " \n",
+ " # Small delay between batches to avoid rate limits\n",
+ " if remaining > 0:\n",
+ " time.sleep(1.5)\n",
+ " \n",
+ " except Exception as e:\n",
+ " print(f\"โ Batch failed: {str(e)}\")\n",
+ " print(f\"๐ Retrying with smaller batch size...\")\n",
+ " \n",
+ " # Try with smaller batch size\n",
+ " smaller_batch = max(1, current_batch_size // 2)\n",
+ " if smaller_batch < current_batch_size:\n",
+ " try:\n",
+ " print(f\"๐ Retrying with {smaller_batch} rows...\")\n",
+ " batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n",
+ " all_dataframes.append(batch_df)\n",
+ " remaining -= len(batch_df)\n",
+ " continue\n",
+ " except Exception as e2:\n",
+ " print(f\"โ Retry also failed: {str(e2)}\")\n",
+ " \n",
+ " print(f\"Using rule-based fallback for remaining {remaining} rows\")\n",
+ " fallback_df = generate_rule_based(CFG, remaining)\n",
+ " all_dataframes.append(fallback_df)\n",
+ " break\n",
+ " \n",
+ " if all_dataframes:\n",
+ " result = pd.concat(all_dataframes, ignore_index=True)\n",
+ " print(f\"โ
Generated total: {len(result)} survey responses\")\n",
+ " return result\n",
+ " else:\n",
+ " print(\"โ No data generated\")\n",
+ " return pd.DataFrame()\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e1af410e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "๐งช Testing LLM generation...\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5236 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ Generated dataset shape: (10, 18)\n",
+ "\n",
+ "๐ First few rows:\n",
+ " response_id respondent_id submitted_at \\\n",
+ "0 f3e9b9d1-4e9e-4f8a-9b5c-7e3cbb1c4e5e 10234 2023-10-01 14:23:45 \n",
+ "1 a1c5f6d3-1f5b-4e8a-8c7a-5e2c3f4b8e1b 20456 2023-10-01 15:10:12 \n",
+ "2 c2b3e4f5-5d6e-4b8a-9f3c-8e1a2f9b4e3c 30567 2023-10-01 16:45:30 \n",
+ "3 d4e5f6b7-6e8f-4b9a-8c7d-9e2f3c4b5e6f 40678 2023-10-01 17:30:00 \n",
+ "4 e5f6a7b8-7f9a-4c0a-9e2f-1e3c4b5e6f7a 50789 2023-10-01 18:15:15 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 KE en android 29 female bachelor upper_mid \n",
+ "1 UG sw web 34 male secondary lower_mid \n",
+ "2 TZ en ios 42 nonbinary diploma high \n",
+ "3 RW sw android 27 female bachelor upper_mid \n",
+ "4 NG en web 36 male postgraduate high \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 450.0 True 4 5 4 \n",
+ "1 600.5 True 3 4 3 \n",
+ "2 720.0 True 5 5 5 \n",
+ "3 390.0 True 4 4 4 \n",
+ "4 800.0 True 5 5 5 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 5 9 False \n",
+ "1 4 7 False \n",
+ "2 5 10 False \n",
+ "3 4 8 False \n",
+ "4 5 9 False \n",
+ "\n",
+ "๐ Data types:\n",
+ "response_id object\n",
+ "respondent_id int64\n",
+ "submitted_at object\n",
+ "country object\n",
+ "language object\n",
+ "device object\n",
+ "age int64\n",
+ "gender object\n",
+ "education object\n",
+ "income_band object\n",
+ "completion_seconds float64\n",
+ "attention_passed bool\n",
+ "q_quality int64\n",
+ "q_value int64\n",
+ "q_ease int64\n",
+ "q_support int64\n",
+ "nps int64\n",
+ "is_detractor bool\n",
+ "dtype: object\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test the fixed LLM generation\n",
+ "print(\"๐งช Testing LLM generation...\")\n",
+ "\n",
+ "# Test with small dataset first\n",
+ "test_df = fixed_llm_generate_batch(CFG, 10)\n",
+ "print(f\"\\n๐ Generated dataset shape: {test_df.shape}\")\n",
+ "print(f\"\\n๐ First few rows:\")\n",
+ "print(test_df.head())\n",
+ "print(f\"\\n๐ Data types:\")\n",
+ "print(test_df.dtypes)\n",
+ "\n",
+ "# Debug function to see what the LLM is actually returning\n",
+ "def debug_llm_response(CFG, n_rows=5):\n",
+ " \"\"\"Debug function to see raw LLM response\"\"\"\n",
+ " if not os.getenv('OPENAI_API_KEY'):\n",
+ " print(\"No OpenAI API key available for debugging\")\n",
+ " return\n",
+ " \n",
+ " try:\n",
+ " from openai import OpenAI\n",
+ " client = OpenAI()\n",
+ " \n",
+ " prompt = create_survey_prompt(CFG, n_rows)\n",
+ " \n",
+ " print(f\"\\n๐ DEBUG: Testing with {n_rows} rows\")\n",
+ " print(f\"๐ Prompt length: {len(prompt)} characters\")\n",
+ " \n",
+ " response = client.chat.completions.create(\n",
+ " model='gpt-4o-mini',\n",
+ " messages=[\n",
+ " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n",
+ " {'role': 'user', 'content': prompt}\n",
+ " ],\n",
+ " temperature=0.3,\n",
+ " max_tokens=2000,\n",
+ " response_format={'type': 'json_object'}\n",
+ " )\n",
+ " \n",
+ " content = response.choices[0].message.content\n",
+ " print(f\"๐ Raw response length: {len(content)} characters\")\n",
+ " print(f\"๐ First 200 characters: {content[:200]}\")\n",
+ " print(f\"๐ Last 200 characters: {content[-200:]}\")\n",
+ " \n",
+ " # Try to parse\n",
+ " try:\n",
+ " data = json.loads(content)\n",
+ " print(f\"โ
JSON parsed successfully\")\n",
+ " print(f\"๐ Data type: {type(data)}\")\n",
+ " if isinstance(data, dict):\n",
+ " print(f\"๐ Dict keys: {list(data.keys())}\")\n",
+ " elif isinstance(data, list):\n",
+ " print(f\"๐ List length: {len(data)}\")\n",
+ " except Exception as e:\n",
+ " print(f\"โ JSON parsing failed: {str(e)}\")\n",
+ " \n",
+ " except Exception as e:\n",
+ " print(f\"โ Debug failed: {str(e)}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "75c90739",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "๐งช Testing the fixed LLM generation...\n",
+ "๐ Generating 5 survey responses with LLM...\n",
+ "๐ Using max_tokens: 2000 (estimated: 2000)\n",
+ "๐ Raw response length: 2629 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 5 rows\n",
+ "โ
Successfully generated 5 survey responses\n",
+ "\n",
+ "๐ Generated dataset shape: (5, 18)\n",
+ "\n",
+ "๐ First few rows:\n",
+ " response_id respondent_id submitted_at \\\n",
+ "0 d8b1c6f3-6f7a-4b4f-9c5f-3a5f8b6e2f1e 12345 2023-10-01 14:30:00 \n",
+ "1 f3a8e3c1-9b4e-4e5e-9c2b-8f5e3c9b1f3d 67890 2023-10-01 15:00:00 \n",
+ "2 c9c8e3f1-2b4f-4a6c-8c2e-2a5f3c8e1f2b 54321 2023-10-01 16:15:00 \n",
+ "3 a5b3c6d2-1e4f-4c5e-9a1f-1f6a7b8e3c9f 98765 2023-10-01 17:45:00 \n",
+ "4 b8f4c3e2-2e4f-4c5e-8a2f-4c5e3b8e2f1a 13579 2023-10-01 18:30:00 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 KE en android 29 female bachelor upper_mid \n",
+ "1 UG sw web 34 male diploma lower_mid \n",
+ "2 TZ en ios 42 nonbinary postgraduate high \n",
+ "3 RW sw android 27 female secondary low \n",
+ "4 NG en web 55 male bachelor upper_mid \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 420.0 True 5 4 4 \n",
+ "1 600.0 True 3 3 2 \n",
+ "2 300.5 True 4 5 4 \n",
+ "3 720.0 False 2 3 3 \n",
+ "4 540.0 True 5 5 5 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 5 9 False \n",
+ "1 4 5 False \n",
+ "2 5 10 False \n",
+ "3 2 3 True \n",
+ "4 5 8 False \n",
+ "\n",
+ "๐ Data types:\n",
+ "response_id object\n",
+ "respondent_id int64\n",
+ "submitted_at object\n",
+ "country object\n",
+ "language object\n",
+ "device object\n",
+ "age int64\n",
+ "gender object\n",
+ "education object\n",
+ "income_band object\n",
+ "completion_seconds float64\n",
+ "attention_passed bool\n",
+ "q_quality int64\n",
+ "q_value int64\n",
+ "q_ease int64\n",
+ "q_support int64\n",
+ "nps int64\n",
+ "is_detractor bool\n",
+ "dtype: object\n",
+ "\n",
+ "โ
SUCCESS! LLM generation is now working!\n",
+ "๐ Generated 5 survey responses using LLM\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test the fixed implementation\n",
+ "print(\"๐งช Testing the fixed LLM generation...\")\n",
+ "\n",
+ "# Test with small dataset\n",
+ "test_df = fixed_llm_generate_batch(CFG, 5)\n",
+ "print(f\"\\n๐ Generated dataset shape: {test_df.shape}\")\n",
+ "print(f\"\\n๐ First few rows:\")\n",
+ "print(test_df.head())\n",
+ "print(f\"\\n๐ Data types:\")\n",
+ "print(test_df.dtypes)\n",
+ "\n",
+ "if not test_df.empty:\n",
+ " print(f\"\\nโ
SUCCESS! LLM generation is now working!\")\n",
+ " print(f\"๐ Generated {len(test_df)} survey responses using LLM\")\n",
+ "else:\n",
+ " print(f\"\\nโ Still having issues with LLM generation\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "dd83b842",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "๐ Testing larger dataset generation...\n",
+ "๐ Generating 100 survey responses with adaptive batching\n",
+ "๐ Using optimal batch size: 10\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 100)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5238 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 90)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5235 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 80)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5232 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 70)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5239 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 60)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5238 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 50)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5236 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 40)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5229 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 30)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5244 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 20)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5234 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 10 rows (remaining: 10)\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5238 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "โ
Generated total: 100 survey responses\n",
+ "\n",
+ "๐ Large dataset shape: (100, 18)\n",
+ "\n",
+ "๐ Summary statistics:\n",
+ " respondent_id age completion_seconds q_quality q_value \\\n",
+ "count 100.000000 100.000000 100.000000 100.000000 100.000000 \n",
+ "mean 33513.700000 34.070000 588.525000 3.740000 3.910000 \n",
+ "std 29233.800863 7.835757 230.530212 1.001211 0.995901 \n",
+ "min 10001.000000 22.000000 120.500000 2.000000 2.000000 \n",
+ "25% 10009.000000 28.000000 420.375000 3.000000 3.000000 \n",
+ "50% 15122.500000 33.000000 600.000000 4.000000 4.000000 \n",
+ "75% 55955.750000 39.250000 720.000000 5.000000 5.000000 \n",
+ "max 98765.000000 50.000000 1500.000000 5.000000 5.000000 \n",
+ "\n",
+ " q_ease q_support nps \n",
+ "count 100.000000 100.000000 100.000000 \n",
+ "mean 3.900000 3.910000 6.990000 \n",
+ "std 0.937437 0.985706 2.333312 \n",
+ "min 2.000000 2.000000 2.000000 \n",
+ "25% 3.000000 3.000000 5.000000 \n",
+ "50% 4.000000 4.000000 7.000000 \n",
+ "75% 5.000000 5.000000 9.000000 \n",
+ "max 5.000000 5.000000 10.000000 \n",
+ "๐พ Saved: data\\survey_llm_fixed_20251023T005139Z.csv\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\2716383900.py:12: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
+ " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "#Test larger dataset generation \n",
+ "print(\"๐ Testing larger dataset generation...\")\n",
+ "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n",
+ "if not large_df.empty:\n",
+ " print(f\"\\n๐ Large dataset shape: {large_df.shape}\")\n",
+ " print(f\"\\n๐ Summary statistics:\")\n",
+ " print(large_df.describe())\n",
+ " \n",
+ " # Save the results\n",
+ " from pathlib import Path\n",
+ " out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
+ " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
+ " csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n",
+ " large_df.to_csv(csv_path, index=False)\n",
+ " print(f\"๐พ Saved: {csv_path}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6029d3e2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LLM available: True\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "def build_json_schema(CFG):\n",
+ " schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n",
+ " props = schema['items']['properties']; req = schema['items']['required']\n",
+ " for f in CFG['fields']:\n",
+ " name, t = f['name'], f['type']\n",
+ " req.append(name)\n",
+ " if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n",
+ " elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n",
+ " elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n",
+ " elif t == 'bool': props[name] = {'type':'boolean'}\n",
+ " else: props[name] = {'type':'string'}\n",
+ " return schema\n",
+ "\n",
+ "PROMPT_PREAMBLE = (\n",
+ " \"You are a data generator. Return ONLY JSON. \"\n",
+ " \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n",
+ " \"No prose, no code fences, no trailing commas.\"\n",
+ ")\n",
+ "\n",
+ "def render_prompt(CFG, n_rows=100):\n",
+ " minimal_cfg = {'fields': []}\n",
+ " for f in CFG['fields']:\n",
+ " base = {k: f[k] for k in ['name','type'] if k in f}\n",
+ " if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n",
+ " if 'values' in f: base.update({'values': f['values']})\n",
+ " if 'fmt' in f: base.update({'fmt': f['fmt']})\n",
+ " minimal_cfg['fields'].append(base)\n",
+ " return {\n",
+ " 'preamble': PROMPT_PREAMBLE,\n",
+ " 'n_rows': n_rows,\n",
+ " 'schema': build_json_schema(CFG),\n",
+ " 'constraints': minimal_cfg,\n",
+ " 'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n",
+ " }\n",
+ "\n",
+ "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n",
+ " try:\n",
+ " obj = json.loads(raw)\n",
+ " if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n",
+ " return pd.DataFrame(obj['rows'])\n",
+ " except Exception:\n",
+ " pass\n",
+ " data = extract_strict_json(raw)\n",
+ " return pd.DataFrame(data)\n",
+ "\n",
+ "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n",
+ "print('LLM available:', USE_LLM)\n",
+ "\n",
+ "def llm_generate_batch(CFG, n_rows=50):\n",
+ " if USE_LLM:\n",
+ " try:\n",
+ " from openai import OpenAI\n",
+ " client = OpenAI()\n",
+ " prompt = json.dumps(render_prompt(CFG, n_rows))\n",
+ " resp = client.chat.completions.create(\n",
+ " model='gpt-4o-mini',\n",
+ " response_format={'type': 'json_object'},\n",
+ " messages=[\n",
+ " {'role':'system','content':'You output strict JSON only.'},\n",
+ " {'role':'user','content': prompt}\n",
+ " ],\n",
+ " temperature=0.2,\n",
+ " max_tokens=8192,\n",
+ " )\n",
+ " raw = resp.choices[0].message.content\n",
+ " try:\n",
+ " return parse_llm_json_to_df(raw)\n",
+ " except Exception:\n",
+ " stricter = (\n",
+ " prompt\n",
+ " + \"\\nReturn ONLY a JSON object structured as: \"\n",
+ " + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n",
+ " + \"No prose, no explanations.\"\n",
+ " )\n",
+ " resp2 = client.chat.completions.create(\n",
+ " model='gpt-4o-mini',\n",
+ " response_format={'type': 'json_object'},\n",
+ " messages=[\n",
+ " {'role':'system','content':'You output strict JSON only.'},\n",
+ " {'role':'user','content': stricter}\n",
+ " ],\n",
+ " temperature=0.2,\n",
+ " max_tokens=8192,\n",
+ " )\n",
+ " raw2 = resp2.choices[0].message.content\n",
+ " return parse_llm_json_to_df(raw2)\n",
+ " except Exception as e:\n",
+ " print('LLM error, fallback to rule-based mock:', e)\n",
+ " tmp = dict(CFG); tmp['rows'] = n_rows\n",
+ " return generate_rule_based(tmp)\n",
+ "\n",
+ "def generate_llm(CFG, total_rows=200, batch_size=50):\n",
+ " dfs = []; remaining = total_rows\n",
+ " while remaining > 0:\n",
+ " b = min(batch_size, remaining)\n",
+ " dfs.append(llm_generate_batch(CFG, n_rows=b))\n",
+ " remaining -= b\n",
+ " time.sleep(0.2)\n",
+ " return pd.concat(dfs, ignore_index=True)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "2e759087",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LLM error, fallback to rule-based mock: No JSON array found in model output.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " response_id | \n",
+ " respondent_id | \n",
+ " submitted_at | \n",
+ " country | \n",
+ " language | \n",
+ " device | \n",
+ " age | \n",
+ " gender | \n",
+ " education | \n",
+ " income_band | \n",
+ " completion_seconds | \n",
+ " attention_passed | \n",
+ " q_quality | \n",
+ " q_value | \n",
+ " q_ease | \n",
+ " q_support | \n",
+ " nps | \n",
+ " is_detractor | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 9e7811bd-27ee-4b7c-9b7a-c98441e337f0 | \n",
+ " 40160 | \n",
+ " 2024-08-18 19:10:06 | \n",
+ " KE | \n",
+ " sw | \n",
+ " web | \n",
+ " 28 | \n",
+ " male | \n",
+ " secondary | \n",
+ " lower_mid | \n",
+ " 1800.000000 | \n",
+ " True | \n",
+ " 4 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 85ec8b90-5468-4880-8309-e325da14d877 | \n",
+ " 55381 | \n",
+ " 2025-01-24 12:21:13 | \n",
+ " TZ | \n",
+ " sw | \n",
+ " ios | \n",
+ " 23 | \n",
+ " female | \n",
+ " bachelor | \n",
+ " high | \n",
+ " 431.412783 | \n",
+ " True | \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 4 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 498dff10-040f-4206-8170-dfce0d5a69f0 | \n",
+ " 48338 | \n",
+ " 2025-07-15 22:21:54 | \n",
+ " TZ | \n",
+ " en | \n",
+ " ios | \n",
+ " 49 | \n",
+ " male | \n",
+ " bachelor | \n",
+ " low | \n",
+ " 1800.000000 | \n",
+ " True | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " ddf11d94-5d6e-4322-9811-4e763f5ed46b | \n",
+ " 59925 | \n",
+ " 2025-01-27 00:16:57 | \n",
+ " KE | \n",
+ " en | \n",
+ " web | \n",
+ " 22 | \n",
+ " male | \n",
+ " bachelor | \n",
+ " upper_mid | \n",
+ " 656.050991 | \n",
+ " True | \n",
+ " 4 | \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 2ef22a0c-fd13-4798-9276-f43831b8f7bc | \n",
+ " 68993 | \n",
+ " 2024-08-19 04:21:49 | \n",
+ " KE | \n",
+ " en | \n",
+ " android | \n",
+ " 40 | \n",
+ " male | \n",
+ " secondary | \n",
+ " lower_mid | \n",
+ " 1553.938944 | \n",
+ " True | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ " 5 | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " response_id respondent_id submitted_at \\\n",
+ "0 9e7811bd-27ee-4b7c-9b7a-c98441e337f0 40160 2024-08-18 19:10:06 \n",
+ "1 85ec8b90-5468-4880-8309-e325da14d877 55381 2025-01-24 12:21:13 \n",
+ "2 498dff10-040f-4206-8170-dfce0d5a69f0 48338 2025-07-15 22:21:54 \n",
+ "3 ddf11d94-5d6e-4322-9811-4e763f5ed46b 59925 2025-01-27 00:16:57 \n",
+ "4 2ef22a0c-fd13-4798-9276-f43831b8f7bc 68993 2024-08-19 04:21:49 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 KE sw web 28 male secondary lower_mid \n",
+ "1 TZ sw ios 23 female bachelor high \n",
+ "2 TZ en ios 49 male bachelor low \n",
+ "3 KE en web 22 male bachelor upper_mid \n",
+ "4 KE en android 40 male secondary lower_mid \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 1800.000000 True 4 3 3 \n",
+ "1 431.412783 True 3 2 3 \n",
+ "2 1800.000000 True 2 3 3 \n",
+ "3 656.050991 True 4 4 1 \n",
+ "4 1553.938944 True 2 2 5 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 3 4 True \n",
+ "1 4 4 False \n",
+ "2 1 3 False \n",
+ "3 3 5 False \n",
+ "4 1 5 False "
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n",
+ "df_llm.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "6d4908ad",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "๐งช Testing improved LLM generation with adaptive batching...\n",
+ "\n",
+ "๐ฆ Testing small batch (10 rows)...\n",
+ "๐ Generating 10 survey responses with LLM...\n",
+ "๐ Using max_tokens: 3500 (estimated: 3500)\n",
+ "๐ Raw response length: 5233 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 10 rows\n",
+ "โ
Successfully generated 10 survey responses\n",
+ "โ
Small batch result: 10 rows\n",
+ "\n",
+ "๐ฆ Testing medium dataset (30 rows) with adaptive batching...\n",
+ "๐ Generating 30 survey responses with adaptive batching\n",
+ "๐ Using optimal batch size: 15\n",
+ "\n",
+ "๐ฆ Processing batch: 15 rows (remaining: 30)\n",
+ "๐ Generating 15 survey responses with LLM...\n",
+ "๐ Using max_tokens: 5000 (estimated: 5000)\n",
+ "๐ Raw response length: 7839 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 15 rows\n",
+ "โ
Successfully generated 15 survey responses\n",
+ "\n",
+ "๐ฆ Processing batch: 15 rows (remaining: 15)\n",
+ "๐ Generating 15 survey responses with LLM...\n",
+ "๐ Using max_tokens: 5000 (estimated: 5000)\n",
+ "๐ Raw response length: 7841 characters\n",
+ "๐ Parsed JSON type: \n",
+ "๐ Found data in 'responses': 15 rows\n",
+ "โ
Successfully generated 15 survey responses\n",
+ "โ
Generated total: 30 survey responses\n",
+ "โ
Medium dataset result: 30 rows\n",
+ "\n",
+ "๐ Dataset shape: (30, 18)\n",
+ "\n",
+ "๐ First few rows:\n",
+ " response_id respondent_id submitted_at \\\n",
+ "0 d1e5c4a3-4b1f-4f6b-8f9e-9f1e1f2e3d4c 10001 2023-10-01 14:30:00 \n",
+ "1 c2b1d4a6-7f8e-4c5c-9d8f-1e2c3b4a5e6f 10002 2023-10-01 15:00:00 \n",
+ "2 e3f2c5b7-8a2d-4c8e-9f1b-2c3d4e5f6a7b 10003 2023-10-01 15:30:00 \n",
+ "3 f4a5b6c8-9d3e-4b1f-9f2c-3d4e5f6a7b8c 10004 2023-10-01 16:00:00 \n",
+ "4 g5b6c7d9-0e4f-4b2a-8f3d-4e5f6a7b8c9d 10005 2023-10-01 16:30:00 \n",
+ "\n",
+ " country language device age gender education income_band \\\n",
+ "0 KE en android 28 female bachelor upper_mid \n",
+ "1 UG sw web 35 male diploma lower_mid \n",
+ "2 TZ en ios 42 nonbinary postgraduate high \n",
+ "3 RW sw web 29 female secondary upper_mid \n",
+ "4 NG en android 50 male bachelor high \n",
+ "\n",
+ " completion_seconds attention_passed q_quality q_value q_ease \\\n",
+ "0 450.0 True 5 4 5 \n",
+ "1 600.0 True 3 2 4 \n",
+ "2 720.0 True 4 5 4 \n",
+ "3 300.0 True 3 3 3 \n",
+ "4 540.0 True 5 5 5 \n",
+ "\n",
+ " q_support nps is_detractor \n",
+ "0 4 9 False \n",
+ "1 3 5 False \n",
+ "2 5 10 False \n",
+ "3 4 6 False \n",
+ "4 5 10 False \n",
+ "๐พ Saved: data\\survey_adaptive_batch_20251023T005927Z.csv\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1770033334.py:22: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
+ " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test the improved LLM generation with adaptive batching\n",
+ "print(\"๐งช Testing improved LLM generation with adaptive batching...\")\n",
+ "\n",
+ "# Test with smaller dataset first\n",
+ "print(\"\\n๐ฆ Testing small batch (10 rows)...\")\n",
+ "small_df = fixed_llm_generate_batch(CFG, 10)\n",
+ "print(f\"โ
Small batch result: {len(small_df)} rows\")\n",
+ "\n",
+ "# Test with medium dataset using adaptive batching\n",
+ "print(\"\\n๐ฆ Testing medium dataset (30 rows) with adaptive batching...\")\n",
+ "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n",
+ "print(f\"โ
Medium dataset result: {len(medium_df)} rows\")\n",
+ "\n",
+ "if not medium_df.empty:\n",
+ " print(f\"\\n๐ Dataset shape: {medium_df.shape}\")\n",
+ " print(f\"\\n๐ First few rows:\")\n",
+ " print(medium_df.head())\n",
+ " \n",
+ " # Save the results\n",
+ " from pathlib import Path\n",
+ " out = Path(\"data\"); out.mkdir(exist_ok=True)\n",
+ " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n",
+ " csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n",
+ " medium_df.to_csv(csv_path, index=False)\n",
+ " print(f\"๐พ Saved: {csv_path}\")\n",
+ "else:\n",
+ " print(\"โ Medium dataset generation failed\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}