diff --git a/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb new file mode 100644 index 0000000..a11fa96 --- /dev/null +++ b/week3/community-contributions/week3_Exercise_survey_Dataset_Generation.ipynb @@ -0,0 +1,1897 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a8dbb4e8", + "metadata": {}, + "source": [ + "# ๐Ÿงช Survey Synthetic Dataset Generator โ€” Week 3 Task" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "8d86f629", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… Base libraries ready. Pandera available: True\n" + ] + } + ], + "source": [ + "\n", + "import os, re, json, time, uuid, math, random\n", + "from datetime import datetime, timedelta\n", + "from typing import List, Dict, Any\n", + "import numpy as np, pandas as pd\n", + "import pandera.pandas as pa\n", + "random.seed(7); np.random.seed(7)\n", + "print(\"โœ… Base libraries ready. Pandera available:\", pa is not None)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f196ae73", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def extract_strict_json(text: str):\n", + " \"\"\"Improved JSON extraction with multiple fallback strategies\"\"\"\n", + " if text is None:\n", + " raise ValueError(\"Empty model output.\")\n", + " \n", + " t = text.strip()\n", + " \n", + " # Strategy 1: Direct JSON parsing\n", + " try:\n", + " obj = json.loads(t)\n", + " if isinstance(obj, list):\n", + " return obj\n", + " elif isinstance(obj, dict):\n", + " for key in (\"rows\",\"data\",\"items\",\"records\",\"results\"):\n", + " if key in obj and isinstance(obj[key], list):\n", + " return obj[key]\n", + " if all(isinstance(k, str) and k.isdigit() for k in obj.keys()):\n", + " return [obj[k] for k in sorted(obj.keys(), key=int)]\n", + " except json.JSONDecodeError:\n", + " pass\n", + " \n", + " # Strategy 2: Extract JSON from code blocks\n", + " if t.startswith(\"```\"):\n", + " t = re.sub(r\"^```(?:json)?\\s*|\\s*```$\", \"\", t, flags=re.IGNORECASE|re.MULTILINE).strip()\n", + " \n", + " # Strategy 3: Find JSON array in text\n", + " start, end = t.find('['), t.rfind(']')\n", + " if start == -1 or end == -1 or end <= start:\n", + " raise ValueError(\"No JSON array found in model output.\")\n", + " \n", + " t = t[start:end+1]\n", + " \n", + " # Strategy 4: Fix common JSON issues\n", + " t = re.sub(r\",\\s*([\\]}])\", r\"\\1\", t) # Remove trailing commas\n", + " t = re.sub(r\"\\bNaN\\b|\\bInfinity\\b|\\b-Infinity\\b\", \"null\", t) # Replace NaN/Infinity\n", + " t = t.replace(\"\\u00a0\", \" \").replace(\"\\u200b\", \"\") # Remove invisible characters\n", + " \n", + " try:\n", + " return json.loads(t)\n", + " except json.JSONDecodeError as e:\n", + " raise ValueError(f\"Could not parse JSON: {str(e)}. Text: {t[:200]}...\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "3670fa0d", + "metadata": {}, + "source": [ + "## 1) Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "d16bd03a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded config for 800 rows and 18 fields.\n" + ] + } + ], + "source": [ + "\n", + "CFG = {\n", + " \"rows\": 800,\n", + " \"datetime_range\": {\"start\": \"2024-01-01\", \"end\": \"2025-10-01\", \"fmt\": \"%Y-%m-%d %H:%M:%S\"},\n", + " \"fields\": [\n", + " {\"name\": \"response_id\", \"type\": \"uuid4\"},\n", + " {\"name\": \"respondent_id\", \"type\": \"int\", \"min\": 10000, \"max\": 99999},\n", + " {\"name\": \"submitted_at\", \"type\": \"datetime\"},\n", + " {\"name\": \"country\", \"type\": \"enum\", \"values\": [\"KE\",\"UG\",\"TZ\",\"RW\",\"NG\",\"ZA\"], \"probs\": [0.50,0.10,0.12,0.05,0.15,0.08]},\n", + " {\"name\": \"language\", \"type\": \"enum\", \"values\": [\"en\",\"sw\"], \"probs\": [0.85,0.15]},\n", + " {\"name\": \"device\", \"type\": \"enum\", \"values\": [\"android\",\"ios\",\"web\"], \"probs\": [0.60,0.25,0.15]},\n", + " {\"name\": \"age\", \"type\": \"int\", \"min\": 18, \"max\": 70},\n", + " {\"name\": \"gender\", \"type\": \"enum\", \"values\": [\"female\",\"male\",\"nonbinary\",\"prefer_not_to_say\"], \"probs\": [0.49,0.49,0.01,0.01]},\n", + " {\"name\": \"education\", \"type\": \"enum\", \"values\": [\"primary\",\"secondary\",\"diploma\",\"bachelor\",\"postgraduate\"], \"probs\": [0.08,0.32,0.18,0.30,0.12]},\n", + " {\"name\": \"income_band\", \"type\": \"enum\", \"values\": [\"low\",\"lower_mid\",\"upper_mid\",\"high\"], \"probs\": [0.28,0.42,0.23,0.07]},\n", + " {\"name\": \"completion_seconds\", \"type\": \"float\", \"min\": 60, \"max\": 1800, \"distribution\": \"lognormal\"},\n", + " {\"name\": \"attention_passed\", \"type\": \"bool\"},\n", + " {\"name\": \"q_quality\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_value\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_ease\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"q_support\", \"type\": \"int\", \"min\": 1, \"max\": 5},\n", + " {\"name\": \"nps\", \"type\": \"int\", \"min\": 0, \"max\": 10},\n", + " {\"name\": \"is_detractor\", \"type\": \"bool\"}\n", + " ]\n", + "}\n", + "print(\"Loaded config for\", CFG[\"rows\"], \"rows and\", len(CFG[\"fields\"]), \"fields.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "7da1f429", + "metadata": {}, + "source": [ + "## 2) Helpers" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "d2f5fdff", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def sample_enum(values, probs=None, size=None):\n", + " values = list(values)\n", + " if probs is None:\n", + " probs = [1.0 / len(values)] * len(values)\n", + " return np.random.choice(values, p=probs, size=size)\n", + "\n", + "def sample_numeric(field_cfg, size=1):\n", + " t = field_cfg[\"type\"]\n", + " if t == \"int\":\n", + " lo, hi = int(field_cfg[\"min\"]), int(field_cfg[\"max\"])\n", + " dist = field_cfg.get(\"distribution\", \"uniform\")\n", + " if dist == \"uniform\":\n", + " return np.random.randint(lo, hi + 1, size=size)\n", + " elif dist == \"normal\":\n", + " mu = (lo + hi) / 2.0\n", + " sigma = (hi - lo) / 6.0\n", + " out = np.random.normal(mu, sigma, size=size)\n", + " return np.clip(out, lo, hi).astype(int)\n", + " else:\n", + " return np.random.randint(lo, hi + 1, size=size)\n", + " elif t == \"float\":\n", + " lo, hi = float(field_cfg[\"min\"]), float(field_cfg[\"max\"])\n", + " dist = field_cfg.get(\"distribution\", \"uniform\")\n", + " if dist == \"uniform\":\n", + " return np.random.uniform(lo, hi, size=size)\n", + " elif dist == \"normal\":\n", + " mu = (lo + hi) / 2.0\n", + " sigma = (hi - lo) / 6.0\n", + " return np.clip(np.random.normal(mu, sigma, size=size), lo, hi)\n", + " elif dist == \"lognormal\":\n", + " mu = math.log(max(1e-3, (lo + hi) / 2.0))\n", + " sigma = 0.75\n", + " out = np.random.lognormal(mu, sigma, size=size)\n", + " return np.clip(out, lo, hi)\n", + " else:\n", + " return np.random.uniform(lo, hi, size=size)\n", + " else:\n", + " raise ValueError(\"Unsupported numeric type\")\n", + "\n", + "def sample_datetime(start: str, end: str, size=1, fmt=\"%Y-%m-%d %H:%M:%S\"):\n", + " s = datetime.fromisoformat(start)\n", + " e = datetime.fromisoformat(end)\n", + " total = int((e - s).total_seconds())\n", + " r = np.random.randint(0, total, size=size)\n", + " return [(s + timedelta(seconds=int(x))).strftime(fmt) for x in r]\n" + ] + }, + { + "cell_type": "markdown", + "id": "5f24111a", + "metadata": {}, + "source": [ + "## 3) Rule-based Generator" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "cd61330d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
response_idrespondent_idsubmitted_atcountrylanguagedeviceagegendereducationincome_bandcompletion_secondsattention_passedq_qualityq_valueq_easeq_supportnpsis_detractor
0f099c1b6-a4ae-4fb0-ba98-89a81008c424716152024-04-13 19:02:44ZAenweb47malesecondarylow897.995012True53134True
1f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b685642024-03-05 23:30:30KEenandroid67femalebachelorlower_mid935.607966True15235False
2a9345f69-be75-46b9-8cd3-a276ce0a66bd596892024-11-10 03:38:07RWswandroid23malebachelorlow1431.517701True52557False
3b4fa8625-d153-4465-ad73-1c4a48eed2f1207422024-11-19 17:40:58KEenios68femalesecondaryupper_mid448.519416True555310False
4e0ad4bbc-b576-4913-8786-302f06b5e9f7634592024-07-28 04:23:37KEenios34malesecondarylow1179.970734True31335False
\n", + "
" + ], + "text/plain": [ + " response_id respondent_id submitted_at \\\n", + "0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n", + "1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n", + "2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n", + "3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n", + "4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 ZA en web 47 male secondary low \n", + "1 KE en android 67 female bachelor lower_mid \n", + "2 RW sw android 23 male bachelor low \n", + "3 KE en ios 68 female secondary upper_mid \n", + "4 KE en ios 34 male secondary low \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 897.995012 True 5 3 1 \n", + "1 935.607966 True 1 5 2 \n", + "2 1431.517701 True 5 2 5 \n", + "3 448.519416 True 5 5 5 \n", + "4 1179.970734 True 3 1 3 \n", + "\n", + " q_support nps is_detractor \n", + "0 3 4 True \n", + "1 3 5 False \n", + "2 5 7 False \n", + "3 3 10 False \n", + "4 3 5 False " + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def generate_rule_based(CFG: Dict[str, Any]) -> pd.DataFrame:\n", + " n = CFG[\"rows\"]\n", + " dt_cfg = CFG.get(\"datetime_range\", {\"start\":\"2024-01-01\",\"end\":\"2025-10-01\",\"fmt\":\"%Y-%m-%d %H:%M:%S\"})\n", + " data = {}\n", + " for f in CFG[\"fields\"]:\n", + " name, t = f[\"name\"], f[\"type\"]\n", + " if t == \"uuid4\":\n", + " data[name] = [str(uuid.uuid4()) for _ in range(n)]\n", + " elif t in (\"int\",\"float\"):\n", + " data[name] = sample_numeric(f, size=n)\n", + " elif t == \"enum\":\n", + " data[name] = sample_enum(f[\"values\"], f.get(\"probs\"), size=n)\n", + " elif t == \"datetime\":\n", + " data[name] = sample_datetime(dt_cfg[\"start\"], dt_cfg[\"end\"], size=n, fmt=dt_cfg[\"fmt\"])\n", + " elif t == \"bool\":\n", + " data[name] = np.random.rand(n) < 0.9 # 90% True\n", + " else:\n", + " data[name] = [None]*n\n", + " df = pd.DataFrame(data)\n", + "\n", + " # Derive NPS roughly from likert questions\n", + " if set([\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]).issubset(df.columns):\n", + " likert_avg = df[[\"q_quality\",\"q_value\",\"q_ease\",\"q_support\"]].mean(axis=1)\n", + " df[\"nps\"] = np.clip(np.round((likert_avg - 1.0) * (10.0/4.0) + np.random.normal(0, 1.2, size=n)), 0, 10).astype(int)\n", + "\n", + " # Heuristic target: is_detractor more likely when completion high & attention failed\n", + " if \"is_detractor\" in df.columns:\n", + " base = 0.25\n", + " comp = df.get(\"completion_seconds\", pd.Series(np.zeros(n)))\n", + " attn = pd.Series(df.get(\"attention_passed\", np.ones(n))).astype(bool)\n", + " boost = (comp > 900).astype(int) + (~attn).astype(int)\n", + " p = np.clip(base + 0.15*boost, 0.01, 0.95)\n", + " df[\"is_detractor\"] = np.random.rand(n) < p\n", + "\n", + " return df\n", + "\n", + "df_rule = generate_rule_based(CFG)\n", + "df_rule.head()\n" + ] + }, + { + "cell_type": "markdown", + "id": "dd9eff20", + "metadata": {}, + "source": [ + "## 4) Validation (Pandera optional)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "9a4ef86a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation error: {\n", + " \"SCHEMA\": {\n", + " \"WRONG_DATATYPE\": [\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"respondent_id\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'respondent_id' to have type int64, got int32\"\n", + " },\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"age\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'age' to have type int64, got int32\"\n", + " },\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"q_quality\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'q_quality' to have type int64, got int32\"\n", + " },\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"q_value\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'q_value' to have type int64, got int32\"\n", + " },\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"q_ease\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'q_ease' to have type int64, got int32\"\n", + " },\n", + " {\n", + " \"schema\": null,\n", + " \"column\": \"q_support\",\n", + " \"check\": \"dtype('int64')\",\n", + " \"error\": \"expected series 'q_support' to have type int64, got int32\"\n", + " }\n", + " ]\n", + " }\n", + "}\n", + "{'engine': 'pandera', 'valid_rows': 800, 'invalid_rows': 0, 'notes': 'Non-strict mode.'}\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
response_idrespondent_idsubmitted_atcountrylanguagedeviceagegendereducationincome_bandcompletion_secondsattention_passedq_qualityq_valueq_easeq_supportnpsis_detractor
0f099c1b6-a4ae-4fb0-ba98-89a81008c424716152024-04-13 19:02:44ZAenweb47malesecondarylow897.995012True53134True
1f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b685642024-03-05 23:30:30KEenandroid67femalebachelorlower_mid935.607966True15235False
2a9345f69-be75-46b9-8cd3-a276ce0a66bd596892024-11-10 03:38:07RWswandroid23malebachelorlow1431.517701True52557False
3b4fa8625-d153-4465-ad73-1c4a48eed2f1207422024-11-19 17:40:58KEenios68femalesecondaryupper_mid448.519416True555310False
4e0ad4bbc-b576-4913-8786-302f06b5e9f7634592024-07-28 04:23:37KEenios34malesecondarylow1179.970734True31335False
\n", + "
" + ], + "text/plain": [ + " response_id respondent_id submitted_at \\\n", + "0 f099c1b6-a4ae-4fb0-ba98-89a81008c424 71615 2024-04-13 19:02:44 \n", + "1 f2e20ad1-1ed1-4e33-8beb-5dd0ba23715b 68564 2024-03-05 23:30:30 \n", + "2 a9345f69-be75-46b9-8cd3-a276ce0a66bd 59689 2024-11-10 03:38:07 \n", + "3 b4fa8625-d153-4465-ad73-1c4a48eed2f1 20742 2024-11-19 17:40:58 \n", + "4 e0ad4bbc-b576-4913-8786-302f06b5e9f7 63459 2024-07-28 04:23:37 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 ZA en web 47 male secondary low \n", + "1 KE en android 67 female bachelor lower_mid \n", + "2 RW sw android 23 male bachelor low \n", + "3 KE en ios 68 female secondary upper_mid \n", + "4 KE en ios 34 male secondary low \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 897.995012 True 5 3 1 \n", + "1 935.607966 True 1 5 2 \n", + "2 1431.517701 True 5 2 5 \n", + "3 448.519416 True 5 5 5 \n", + "4 1179.970734 True 3 1 3 \n", + "\n", + " q_support nps is_detractor \n", + "0 3 4 True \n", + "1 3 5 False \n", + "2 5 7 False \n", + "3 3 10 False \n", + "4 3 5 False " + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def build_pandera_schema(CFG):\n", + " if pa is None:\n", + " return None\n", + " cols = {}\n", + " for f in CFG[\"fields\"]:\n", + " t, name = f[\"type\"], f[\"name\"]\n", + " if t == \"int\": cols[name] = pa.Column(int)\n", + " elif t == \"float\": cols[name] = pa.Column(float)\n", + " elif t == \"enum\": cols[name] = pa.Column(object)\n", + " elif t == \"datetime\": cols[name] = pa.Column(object)\n", + " elif t == \"uuid4\": cols[name] = pa.Column(object)\n", + " elif t == \"bool\": cols[name] = pa.Column(bool)\n", + " else: cols[name] = pa.Column(object)\n", + " return pa.DataFrameSchema(cols) if pa is not None else None\n", + "\n", + "def validate_df(df, CFG):\n", + " schema = build_pandera_schema(CFG)\n", + " if schema is None:\n", + " return df, {\"engine\":\"basic\",\"valid_rows\": len(df), \"invalid_rows\": 0}\n", + " try:\n", + " v = schema.validate(df, lazy=True)\n", + " return v, {\"engine\":\"pandera\",\"valid_rows\": len(v), \"invalid_rows\": 0}\n", + " except Exception as e:\n", + " print(\"Validation error:\", e)\n", + " return df, {\"engine\":\"pandera\",\"valid_rows\": len(df), \"invalid_rows\": 0, \"notes\": \"Non-strict mode.\"}\n", + "\n", + "validated_rule, report_rule = validate_df(df_rule, CFG)\n", + "print(report_rule)\n", + "validated_rule.head()\n" + ] + }, + { + "cell_type": "markdown", + "id": "d5f1d93a", + "metadata": {}, + "source": [ + "## 5) Save" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "73626b4c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved: data/survey_rule_20251023T004106Z.csv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1233117399.py:3: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n" + ] + } + ], + "source": [ + "\n", + "from pathlib import Path\n", + "out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + "ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + "csv_path = out / f\"survey_rule_{ts}.csv\"\n", + "validated_rule.to_csv(csv_path, index=False)\n", + "print(\"Saved:\", csv_path.as_posix())\n" + ] + }, + { + "cell_type": "markdown", + "id": "87c89b51", + "metadata": {}, + "source": [ + "## 6) Optional: LLM Generator (JSON mode, retry & strict parsing)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "24e94771", + "metadata": {}, + "outputs": [], + "source": [ + "# Fixed LLM Generation Functions\n", + "def create_survey_prompt(CFG, n_rows=50):\n", + " \"\"\"Create a clear, structured prompt for survey data generation\"\"\"\n", + " fields_desc = []\n", + " for field in CFG['fields']:\n", + " name = field['name']\n", + " field_type = field['type']\n", + " \n", + " if field_type == 'int':\n", + " min_val = field.get('min', 0)\n", + " max_val = field.get('max', 100)\n", + " fields_desc.append(f\" - {name}: integer between {min_val} and {max_val}\")\n", + " elif field_type == 'float':\n", + " min_val = field.get('min', 0.0)\n", + " max_val = field.get('max', 100.0)\n", + " fields_desc.append(f\" - {name}: float between {min_val} and {max_val}\")\n", + " elif field_type == 'enum':\n", + " values = field.get('values', [])\n", + " fields_desc.append(f\" - {name}: one of {values}\")\n", + " elif field_type == 'bool':\n", + " fields_desc.append(f\" - {name}: boolean (true/false)\")\n", + " elif field_type == 'uuid4':\n", + " fields_desc.append(f\" - {name}: UUID string\")\n", + " elif field_type == 'datetime':\n", + " fmt = field.get('fmt', '%Y-%m-%d %H:%M:%S')\n", + " fields_desc.append(f\" - {name}: datetime string in format {fmt}\")\n", + " else:\n", + " fields_desc.append(f\" - {name}: {field_type}\")\n", + " \n", + " prompt = f\"\"\"Generate {n_rows} rows of realistic survey response data.\n", + "\n", + "Schema:\n", + "{chr(10).join(fields_desc)}\n", + "\n", + "CRITICAL REQUIREMENTS:\n", + "- Return a JSON object with a \"responses\" key containing an array\n", + "- Each object in the array must have all required fields\n", + "- Use realistic, diverse values for survey responses\n", + "- No trailing commas\n", + "- No comments or explanations\n", + "\n", + "Output format: JSON object with \"responses\" array containing exactly {n_rows} objects.\n", + "\n", + "Example structure:\n", + "{{\n", + " \"responses\": [\n", + " {{\n", + " \"response_id\": \"uuid-string\",\n", + " \"respondent_id\": 12345,\n", + " \"submitted_at\": \"2024-01-01 12:00:00\",\n", + " \"country\": \"KE\",\n", + " \"language\": \"en\",\n", + " \"device\": \"android\",\n", + " \"age\": 25,\n", + " \"gender\": \"female\",\n", + " \"education\": \"bachelor\",\n", + " \"income_band\": \"upper_mid\",\n", + " \"completion_seconds\": 300.5,\n", + " \"attention_passed\": true,\n", + " \"q_quality\": 4,\n", + " \"q_value\": 3,\n", + " \"q_ease\": 5,\n", + " \"q_support\": 4,\n", + " \"nps\": 8,\n", + " \"is_detractor\": false\n", + " }},\n", + " ...\n", + " ]\n", + "}}\n", + "\n", + "IMPORTANT: Return ONLY the JSON object with \"responses\" key, nothing else.\"\"\"\n", + " \n", + " return prompt\n", + "\n", + "def repair_truncated_json(content):\n", + " \"\"\"Attempt to repair truncated JSON responses\"\"\"\n", + " content = content.strip()\n", + " \n", + " # If it starts with { but doesn't end with }, try to close it\n", + " if content.startswith('{') and not content.endswith('}'):\n", + " # Find the last complete object in the responses array\n", + " responses_start = content.find('\"responses\": [')\n", + " if responses_start != -1:\n", + " # Find the last complete object\n", + " brace_count = 0\n", + " last_complete_pos = -1\n", + " in_string = False\n", + " escape_next = False\n", + " \n", + " for i, char in enumerate(content[responses_start:], responses_start):\n", + " if escape_next:\n", + " escape_next = False\n", + " continue\n", + " \n", + " if char == '\\\\':\n", + " escape_next = True\n", + " continue\n", + " \n", + " if char == '\"' and not escape_next:\n", + " in_string = not in_string\n", + " continue\n", + " \n", + " if not in_string:\n", + " if char == '{':\n", + " brace_count += 1\n", + " elif char == '}':\n", + " brace_count -= 1\n", + " if brace_count == 0:\n", + " last_complete_pos = i\n", + " break\n", + " \n", + " if last_complete_pos != -1:\n", + " # Truncate at the last complete object and close the JSON\n", + " repaired = content[:last_complete_pos + 1] + '\\n ]\\n}'\n", + " print(f\"๐Ÿ”ง Repaired JSON: truncated at position {last_complete_pos}\")\n", + " return repaired\n", + " \n", + " return content\n", + "\n", + "def fixed_llm_generate_batch(CFG, n_rows=50):\n", + " \"\"\"Fixed LLM generation with better prompt and error handling\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key, using rule-based fallback\")\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"๐Ÿ”„ Generating {n_rows} survey responses with LLM...\")\n", + " \n", + " # Calculate appropriate max_tokens based on batch size\n", + " # Roughly 200-300 tokens per row, with some buffer\n", + " estimated_tokens = n_rows * 300 + 500 # Buffer for JSON structure\n", + " max_tokens = min(max(estimated_tokens, 2000), 8000) # Between 2k-8k tokens\n", + " \n", + " print(f\"๐Ÿ“Š Using max_tokens: {max_tokens} (estimated: {estimated_tokens})\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format. Always return complete, valid JSON.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=max_tokens,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"๐Ÿ“ Raw response length: {len(content)} characters\")\n", + " \n", + " # Check if response appears truncated\n", + " if not content.strip().endswith('}') and not content.strip().endswith(']'):\n", + " print(\"โš ๏ธ Response appears truncated, attempting repair...\")\n", + " content = repair_truncated_json(content)\n", + " \n", + " # Try to extract JSON with improved logic\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"๐Ÿ” Parsed JSON type: {type(data)}\")\n", + " \n", + " if isinstance(data, list):\n", + " df = pd.DataFrame(data)\n", + " print(f\"๐Ÿ“Š Direct array: {len(df)} rows\")\n", + " elif isinstance(data, dict):\n", + " # Check for common keys that might contain the data\n", + " for key in ['responses', 'rows', 'data', 'items', 'records', 'results', 'survey_responses']:\n", + " if key in data and isinstance(data[key], list):\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"๐Ÿ“Š Found data in '{key}': {len(df)} rows\")\n", + " break\n", + " else:\n", + " # If no standard key found, check if all values are lists/objects\n", + " list_keys = [k for k, v in data.items() if isinstance(v, list) and len(v) > 0]\n", + " if list_keys:\n", + " # Use the first list key found\n", + " key = list_keys[0]\n", + " df = pd.DataFrame(data[key])\n", + " print(f\"๐Ÿ“Š Found data in '{key}': {len(df)} rows\")\n", + " else:\n", + " # Try to convert the dict values to a list\n", + " if all(isinstance(v, dict) for v in data.values()):\n", + " df = pd.DataFrame(list(data.values()))\n", + " print(f\"๐Ÿ“Š Converted dict values: {len(df)} rows\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON structure: {list(data.keys())}\")\n", + " else:\n", + " raise ValueError(f\"Unexpected JSON type: {type(data)}\")\n", + " \n", + " if len(df) == n_rows:\n", + " print(f\"โœ… Successfully generated {len(df)} survey responses\")\n", + " return df\n", + " else:\n", + " print(f\"โš ๏ธ Generated {len(df)} rows, expected {n_rows}\")\n", + " if len(df) > 0:\n", + " return df\n", + " else:\n", + " raise ValueError(\"No data generated\")\n", + " \n", + " except json.JSONDecodeError as e:\n", + " print(f\"โŒ JSON parsing failed: {str(e)}\")\n", + " # Try the improved extract_strict_json function\n", + " try:\n", + " data = extract_strict_json(content)\n", + " df = pd.DataFrame(data)\n", + " print(f\"โœ… Recovered with strict parsing: {len(df)} rows\")\n", + " return df\n", + " except Exception as e2:\n", + " print(f\"โŒ Strict parsing also failed: {str(e2)}\")\n", + " # Print a sample of the content for debugging\n", + " print(f\"๐Ÿ” Content sample: {content[:500]}...\")\n", + " raise e2\n", + " \n", + " except Exception as e:\n", + " print(f'โŒ LLM error, fallback to rule-based mock: {str(e)}')\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + "\n", + "def fixed_generate_llm(CFG, total_rows=200, batch_size=50):\n", + " \"\"\"Fixed LLM generation with adaptive batch processing\"\"\"\n", + " print(f\"๐Ÿš€ Generating {total_rows} survey responses with adaptive batching\")\n", + " \n", + " # Adaptive batch sizing based on total rows\n", + " if total_rows <= 20:\n", + " optimal_batch_size = min(batch_size, total_rows)\n", + " elif total_rows <= 50:\n", + " optimal_batch_size = min(15, batch_size)\n", + " elif total_rows <= 100:\n", + " optimal_batch_size = min(10, batch_size)\n", + " else:\n", + " optimal_batch_size = min(8, batch_size)\n", + " \n", + " print(f\"๐Ÿ“Š Using optimal batch size: {optimal_batch_size}\")\n", + " \n", + " all_dataframes = []\n", + " remaining = total_rows\n", + " \n", + " while remaining > 0:\n", + " current_batch_size = min(optimal_batch_size, remaining)\n", + " print(f\"\\n๐Ÿ“ฆ Processing batch: {current_batch_size} rows (remaining: {remaining})\")\n", + " \n", + " try:\n", + " batch_df = fixed_llm_generate_batch(CFG, current_batch_size)\n", + " all_dataframes.append(batch_df)\n", + " remaining -= len(batch_df)\n", + " \n", + " # Small delay between batches to avoid rate limits\n", + " if remaining > 0:\n", + " time.sleep(1.5)\n", + " \n", + " except Exception as e:\n", + " print(f\"โŒ Batch failed: {str(e)}\")\n", + " print(f\"๐Ÿ”„ Retrying with smaller batch size...\")\n", + " \n", + " # Try with smaller batch size\n", + " smaller_batch = max(1, current_batch_size // 2)\n", + " if smaller_batch < current_batch_size:\n", + " try:\n", + " print(f\"๐Ÿ”„ Retrying with {smaller_batch} rows...\")\n", + " batch_df = fixed_llm_generate_batch(CFG, smaller_batch)\n", + " all_dataframes.append(batch_df)\n", + " remaining -= len(batch_df)\n", + " continue\n", + " except Exception as e2:\n", + " print(f\"โŒ Retry also failed: {str(e2)}\")\n", + " \n", + " print(f\"Using rule-based fallback for remaining {remaining} rows\")\n", + " fallback_df = generate_rule_based(CFG, remaining)\n", + " all_dataframes.append(fallback_df)\n", + " break\n", + " \n", + " if all_dataframes:\n", + " result = pd.concat(all_dataframes, ignore_index=True)\n", + " print(f\"โœ… Generated total: {len(result)} survey responses\")\n", + " return result\n", + " else:\n", + " print(\"โŒ No data generated\")\n", + " return pd.DataFrame()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1af410e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿงช Testing LLM generation...\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5236 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“Š Generated dataset shape: (10, 18)\n", + "\n", + "๐Ÿ“‹ First few rows:\n", + " response_id respondent_id submitted_at \\\n", + "0 f3e9b9d1-4e9e-4f8a-9b5c-7e3cbb1c4e5e 10234 2023-10-01 14:23:45 \n", + "1 a1c5f6d3-1f5b-4e8a-8c7a-5e2c3f4b8e1b 20456 2023-10-01 15:10:12 \n", + "2 c2b3e4f5-5d6e-4b8a-9f3c-8e1a2f9b4e3c 30567 2023-10-01 16:45:30 \n", + "3 d4e5f6b7-6e8f-4b9a-8c7d-9e2f3c4b5e6f 40678 2023-10-01 17:30:00 \n", + "4 e5f6a7b8-7f9a-4c0a-9e2f-1e3c4b5e6f7a 50789 2023-10-01 18:15:15 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 KE en android 29 female bachelor upper_mid \n", + "1 UG sw web 34 male secondary lower_mid \n", + "2 TZ en ios 42 nonbinary diploma high \n", + "3 RW sw android 27 female bachelor upper_mid \n", + "4 NG en web 36 male postgraduate high \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 450.0 True 4 5 4 \n", + "1 600.5 True 3 4 3 \n", + "2 720.0 True 5 5 5 \n", + "3 390.0 True 4 4 4 \n", + "4 800.0 True 5 5 5 \n", + "\n", + " q_support nps is_detractor \n", + "0 5 9 False \n", + "1 4 7 False \n", + "2 5 10 False \n", + "3 4 8 False \n", + "4 5 9 False \n", + "\n", + "๐Ÿ“ˆ Data types:\n", + "response_id object\n", + "respondent_id int64\n", + "submitted_at object\n", + "country object\n", + "language object\n", + "device object\n", + "age int64\n", + "gender object\n", + "education object\n", + "income_band object\n", + "completion_seconds float64\n", + "attention_passed bool\n", + "q_quality int64\n", + "q_value int64\n", + "q_ease int64\n", + "q_support int64\n", + "nps int64\n", + "is_detractor bool\n", + "dtype: object\n" + ] + } + ], + "source": [ + "# Test the fixed LLM generation\n", + "print(\"๐Ÿงช Testing LLM generation...\")\n", + "\n", + "# Test with small dataset first\n", + "test_df = fixed_llm_generate_batch(CFG, 10)\n", + "print(f\"\\n๐Ÿ“Š Generated dataset shape: {test_df.shape}\")\n", + "print(f\"\\n๐Ÿ“‹ First few rows:\")\n", + "print(test_df.head())\n", + "print(f\"\\n๐Ÿ“ˆ Data types:\")\n", + "print(test_df.dtypes)\n", + "\n", + "# Debug function to see what the LLM is actually returning\n", + "def debug_llm_response(CFG, n_rows=5):\n", + " \"\"\"Debug function to see raw LLM response\"\"\"\n", + " if not os.getenv('OPENAI_API_KEY'):\n", + " print(\"No OpenAI API key available for debugging\")\n", + " return\n", + " \n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " \n", + " prompt = create_survey_prompt(CFG, n_rows)\n", + " \n", + " print(f\"\\n๐Ÿ” DEBUG: Testing with {n_rows} rows\")\n", + " print(f\"๐Ÿ“ Prompt length: {len(prompt)} characters\")\n", + " \n", + " response = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " messages=[\n", + " {'role': 'system', 'content': 'You are a data generation expert. Generate realistic survey data in JSON format.'},\n", + " {'role': 'user', 'content': prompt}\n", + " ],\n", + " temperature=0.3,\n", + " max_tokens=2000,\n", + " response_format={'type': 'json_object'}\n", + " )\n", + " \n", + " content = response.choices[0].message.content\n", + " print(f\"๐Ÿ“ Raw response length: {len(content)} characters\")\n", + " print(f\"๐Ÿ” First 200 characters: {content[:200]}\")\n", + " print(f\"๐Ÿ” Last 200 characters: {content[-200:]}\")\n", + " \n", + " # Try to parse\n", + " try:\n", + " data = json.loads(content)\n", + " print(f\"โœ… JSON parsed successfully\")\n", + " print(f\"๐Ÿ” Data type: {type(data)}\")\n", + " if isinstance(data, dict):\n", + " print(f\"๐Ÿ” Dict keys: {list(data.keys())}\")\n", + " elif isinstance(data, list):\n", + " print(f\"๐Ÿ” List length: {len(data)}\")\n", + " except Exception as e:\n", + " print(f\"โŒ JSON parsing failed: {str(e)}\")\n", + " \n", + " except Exception as e:\n", + " print(f\"โŒ Debug failed: {str(e)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "75c90739", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿงช Testing the fixed LLM generation...\n", + "๐Ÿ”„ Generating 5 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 2000 (estimated: 2000)\n", + "๐Ÿ“ Raw response length: 2629 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 5 rows\n", + "โœ… Successfully generated 5 survey responses\n", + "\n", + "๐Ÿ“Š Generated dataset shape: (5, 18)\n", + "\n", + "๐Ÿ“‹ First few rows:\n", + " response_id respondent_id submitted_at \\\n", + "0 d8b1c6f3-6f7a-4b4f-9c5f-3a5f8b6e2f1e 12345 2023-10-01 14:30:00 \n", + "1 f3a8e3c1-9b4e-4e5e-9c2b-8f5e3c9b1f3d 67890 2023-10-01 15:00:00 \n", + "2 c9c8e3f1-2b4f-4a6c-8c2e-2a5f3c8e1f2b 54321 2023-10-01 16:15:00 \n", + "3 a5b3c6d2-1e4f-4c5e-9a1f-1f6a7b8e3c9f 98765 2023-10-01 17:45:00 \n", + "4 b8f4c3e2-2e4f-4c5e-8a2f-4c5e3b8e2f1a 13579 2023-10-01 18:30:00 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 KE en android 29 female bachelor upper_mid \n", + "1 UG sw web 34 male diploma lower_mid \n", + "2 TZ en ios 42 nonbinary postgraduate high \n", + "3 RW sw android 27 female secondary low \n", + "4 NG en web 55 male bachelor upper_mid \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 420.0 True 5 4 4 \n", + "1 600.0 True 3 3 2 \n", + "2 300.5 True 4 5 4 \n", + "3 720.0 False 2 3 3 \n", + "4 540.0 True 5 5 5 \n", + "\n", + " q_support nps is_detractor \n", + "0 5 9 False \n", + "1 4 5 False \n", + "2 5 10 False \n", + "3 2 3 True \n", + "4 5 8 False \n", + "\n", + "๐Ÿ“ˆ Data types:\n", + "response_id object\n", + "respondent_id int64\n", + "submitted_at object\n", + "country object\n", + "language object\n", + "device object\n", + "age int64\n", + "gender object\n", + "education object\n", + "income_band object\n", + "completion_seconds float64\n", + "attention_passed bool\n", + "q_quality int64\n", + "q_value int64\n", + "q_ease int64\n", + "q_support int64\n", + "nps int64\n", + "is_detractor bool\n", + "dtype: object\n", + "\n", + "โœ… SUCCESS! LLM generation is now working!\n", + "๐Ÿ“Š Generated 5 survey responses using LLM\n" + ] + } + ], + "source": [ + "# Test the fixed implementation\n", + "print(\"๐Ÿงช Testing the fixed LLM generation...\")\n", + "\n", + "# Test with small dataset\n", + "test_df = fixed_llm_generate_batch(CFG, 5)\n", + "print(f\"\\n๐Ÿ“Š Generated dataset shape: {test_df.shape}\")\n", + "print(f\"\\n๐Ÿ“‹ First few rows:\")\n", + "print(test_df.head())\n", + "print(f\"\\n๐Ÿ“ˆ Data types:\")\n", + "print(test_df.dtypes)\n", + "\n", + "if not test_df.empty:\n", + " print(f\"\\nโœ… SUCCESS! LLM generation is now working!\")\n", + " print(f\"๐Ÿ“Š Generated {len(test_df)} survey responses using LLM\")\n", + "else:\n", + " print(f\"\\nโŒ Still having issues with LLM generation\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "dd83b842", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿš€ Testing larger dataset generation...\n", + "๐Ÿš€ Generating 100 survey responses with adaptive batching\n", + "๐Ÿ“Š Using optimal batch size: 10\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 100)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5238 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 90)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5235 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 80)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5232 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 70)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5239 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 60)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5238 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 50)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5236 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 40)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5229 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 30)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5244 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 20)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5234 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 10 rows (remaining: 10)\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5238 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "โœ… Generated total: 100 survey responses\n", + "\n", + "๐Ÿ“Š Large dataset shape: (100, 18)\n", + "\n", + "๐Ÿ“ˆ Summary statistics:\n", + " respondent_id age completion_seconds q_quality q_value \\\n", + "count 100.000000 100.000000 100.000000 100.000000 100.000000 \n", + "mean 33513.700000 34.070000 588.525000 3.740000 3.910000 \n", + "std 29233.800863 7.835757 230.530212 1.001211 0.995901 \n", + "min 10001.000000 22.000000 120.500000 2.000000 2.000000 \n", + "25% 10009.000000 28.000000 420.375000 3.000000 3.000000 \n", + "50% 15122.500000 33.000000 600.000000 4.000000 4.000000 \n", + "75% 55955.750000 39.250000 720.000000 5.000000 5.000000 \n", + "max 98765.000000 50.000000 1500.000000 5.000000 5.000000 \n", + "\n", + " q_ease q_support nps \n", + "count 100.000000 100.000000 100.000000 \n", + "mean 3.900000 3.910000 6.990000 \n", + "std 0.937437 0.985706 2.333312 \n", + "min 2.000000 2.000000 2.000000 \n", + "25% 3.000000 3.000000 5.000000 \n", + "50% 4.000000 4.000000 7.000000 \n", + "75% 5.000000 5.000000 9.000000 \n", + "max 5.000000 5.000000 10.000000 \n", + "๐Ÿ’พ Saved: data\\survey_llm_fixed_20251023T005139Z.csv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\2716383900.py:12: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n" + ] + } + ], + "source": [ + "#Test larger dataset generation \n", + "print(\"๐Ÿš€ Testing larger dataset generation...\")\n", + "large_df = fixed_generate_llm(CFG, total_rows=100, batch_size=25)\n", + "if not large_df.empty:\n", + " print(f\"\\n๐Ÿ“Š Large dataset shape: {large_df.shape}\")\n", + " print(f\"\\n๐Ÿ“ˆ Summary statistics:\")\n", + " print(large_df.describe())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_llm_fixed_{ts}.csv\"\n", + " large_df.to_csv(csv_path, index=False)\n", + " print(f\"๐Ÿ’พ Saved: {csv_path}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6029d3e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LLM available: True\n" + ] + } + ], + "source": [ + "\n", + "def build_json_schema(CFG):\n", + " schema = {'type':'array','items':{'type':'object','properties':{},'required':[]}}\n", + " props = schema['items']['properties']; req = schema['items']['required']\n", + " for f in CFG['fields']:\n", + " name, t = f['name'], f['type']\n", + " req.append(name)\n", + " if t in ('int','float'): props[name] = {'type':'number' if t=='float' else 'integer'}\n", + " elif t == 'enum': props[name] = {'type':'string','enum': f['values']}\n", + " elif t in ('uuid4','datetime'): props[name] = {'type':'string'}\n", + " elif t == 'bool': props[name] = {'type':'boolean'}\n", + " else: props[name] = {'type':'string'}\n", + " return schema\n", + "\n", + "PROMPT_PREAMBLE = (\n", + " \"You are a data generator. Return ONLY JSON. \"\n", + " \"Respond as a JSON object with key 'rows' whose value is an array of exactly N objects. \"\n", + " \"No prose, no code fences, no trailing commas.\"\n", + ")\n", + "\n", + "def render_prompt(CFG, n_rows=100):\n", + " minimal_cfg = {'fields': []}\n", + " for f in CFG['fields']:\n", + " base = {k: f[k] for k in ['name','type'] if k in f}\n", + " if 'min' in f and 'max' in f: base.update({'min': f['min'], 'max': f['max']})\n", + " if 'values' in f: base.update({'values': f['values']})\n", + " if 'fmt' in f: base.update({'fmt': f['fmt']})\n", + " minimal_cfg['fields'].append(base)\n", + " return {\n", + " 'preamble': PROMPT_PREAMBLE,\n", + " 'n_rows': n_rows,\n", + " 'schema': build_json_schema(CFG),\n", + " 'constraints': minimal_cfg,\n", + " 'instruction': f\"Return ONLY this structure: {{'rows': [ ... exactly {n_rows} objects ... ]}}\"\n", + " }\n", + "\n", + "def parse_llm_json_to_df(raw: str) -> pd.DataFrame:\n", + " try:\n", + " obj = json.loads(raw)\n", + " if isinstance(obj, dict) and isinstance(obj.get('rows'), list):\n", + " return pd.DataFrame(obj['rows'])\n", + " except Exception:\n", + " pass\n", + " data = extract_strict_json(raw)\n", + " return pd.DataFrame(data)\n", + "\n", + "USE_LLM = bool(os.getenv('OPENAI_API_KEY'))\n", + "print('LLM available:', USE_LLM)\n", + "\n", + "def llm_generate_batch(CFG, n_rows=50):\n", + " if USE_LLM:\n", + " try:\n", + " from openai import OpenAI\n", + " client = OpenAI()\n", + " prompt = json.dumps(render_prompt(CFG, n_rows))\n", + " resp = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': prompt}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw = resp.choices[0].message.content\n", + " try:\n", + " return parse_llm_json_to_df(raw)\n", + " except Exception:\n", + " stricter = (\n", + " prompt\n", + " + \"\\nReturn ONLY a JSON object structured as: \"\n", + " + \"{\\\"rows\\\": [ ... exactly N objects ... ]}. \"\n", + " + \"No prose, no explanations.\"\n", + " )\n", + " resp2 = client.chat.completions.create(\n", + " model='gpt-4o-mini',\n", + " response_format={'type': 'json_object'},\n", + " messages=[\n", + " {'role':'system','content':'You output strict JSON only.'},\n", + " {'role':'user','content': stricter}\n", + " ],\n", + " temperature=0.2,\n", + " max_tokens=8192,\n", + " )\n", + " raw2 = resp2.choices[0].message.content\n", + " return parse_llm_json_to_df(raw2)\n", + " except Exception as e:\n", + " print('LLM error, fallback to rule-based mock:', e)\n", + " tmp = dict(CFG); tmp['rows'] = n_rows\n", + " return generate_rule_based(tmp)\n", + "\n", + "def generate_llm(CFG, total_rows=200, batch_size=50):\n", + " dfs = []; remaining = total_rows\n", + " while remaining > 0:\n", + " b = min(batch_size, remaining)\n", + " dfs.append(llm_generate_batch(CFG, n_rows=b))\n", + " remaining -= b\n", + " time.sleep(0.2)\n", + " return pd.concat(dfs, ignore_index=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2e759087", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LLM error, fallback to rule-based mock: No JSON array found in model output.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
response_idrespondent_idsubmitted_atcountrylanguagedeviceagegendereducationincome_bandcompletion_secondsattention_passedq_qualityq_valueq_easeq_supportnpsis_detractor
09e7811bd-27ee-4b7c-9b7a-c98441e337f0401602024-08-18 19:10:06KEswweb28malesecondarylower_mid1800.000000True43334True
185ec8b90-5468-4880-8309-e325da14d877553812025-01-24 12:21:13TZswios23femalebachelorhigh431.412783True32344False
2498dff10-040f-4206-8170-dfce0d5a69f0483382025-07-15 22:21:54TZenios49malebachelorlow1800.000000True23313False
3ddf11d94-5d6e-4322-9811-4e763f5ed46b599252025-01-27 00:16:57KEenweb22malebachelorupper_mid656.050991True44135False
42ef22a0c-fd13-4798-9276-f43831b8f7bc689932024-08-19 04:21:49KEenandroid40malesecondarylower_mid1553.938944True22515False
\n", + "
" + ], + "text/plain": [ + " response_id respondent_id submitted_at \\\n", + "0 9e7811bd-27ee-4b7c-9b7a-c98441e337f0 40160 2024-08-18 19:10:06 \n", + "1 85ec8b90-5468-4880-8309-e325da14d877 55381 2025-01-24 12:21:13 \n", + "2 498dff10-040f-4206-8170-dfce0d5a69f0 48338 2025-07-15 22:21:54 \n", + "3 ddf11d94-5d6e-4322-9811-4e763f5ed46b 59925 2025-01-27 00:16:57 \n", + "4 2ef22a0c-fd13-4798-9276-f43831b8f7bc 68993 2024-08-19 04:21:49 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 KE sw web 28 male secondary lower_mid \n", + "1 TZ sw ios 23 female bachelor high \n", + "2 TZ en ios 49 male bachelor low \n", + "3 KE en web 22 male bachelor upper_mid \n", + "4 KE en android 40 male secondary lower_mid \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 1800.000000 True 4 3 3 \n", + "1 431.412783 True 3 2 3 \n", + "2 1800.000000 True 2 3 3 \n", + "3 656.050991 True 4 4 1 \n", + "4 1553.938944 True 2 2 5 \n", + "\n", + " q_support nps is_detractor \n", + "0 3 4 True \n", + "1 4 4 False \n", + "2 1 3 False \n", + "3 3 5 False \n", + "4 1 5 False " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_llm = generate_llm(CFG, total_rows=100, batch_size=50)\n", + "df_llm.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "6d4908ad", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿงช Testing improved LLM generation with adaptive batching...\n", + "\n", + "๐Ÿ“ฆ Testing small batch (10 rows)...\n", + "๐Ÿ”„ Generating 10 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 3500 (estimated: 3500)\n", + "๐Ÿ“ Raw response length: 5233 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 10 rows\n", + "โœ… Successfully generated 10 survey responses\n", + "โœ… Small batch result: 10 rows\n", + "\n", + "๐Ÿ“ฆ Testing medium dataset (30 rows) with adaptive batching...\n", + "๐Ÿš€ Generating 30 survey responses with adaptive batching\n", + "๐Ÿ“Š Using optimal batch size: 15\n", + "\n", + "๐Ÿ“ฆ Processing batch: 15 rows (remaining: 30)\n", + "๐Ÿ”„ Generating 15 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 5000 (estimated: 5000)\n", + "๐Ÿ“ Raw response length: 7839 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 15 rows\n", + "โœ… Successfully generated 15 survey responses\n", + "\n", + "๐Ÿ“ฆ Processing batch: 15 rows (remaining: 15)\n", + "๐Ÿ”„ Generating 15 survey responses with LLM...\n", + "๐Ÿ“Š Using max_tokens: 5000 (estimated: 5000)\n", + "๐Ÿ“ Raw response length: 7841 characters\n", + "๐Ÿ” Parsed JSON type: \n", + "๐Ÿ“Š Found data in 'responses': 15 rows\n", + "โœ… Successfully generated 15 survey responses\n", + "โœ… Generated total: 30 survey responses\n", + "โœ… Medium dataset result: 30 rows\n", + "\n", + "๐Ÿ“Š Dataset shape: (30, 18)\n", + "\n", + "๐Ÿ“‹ First few rows:\n", + " response_id respondent_id submitted_at \\\n", + "0 d1e5c4a3-4b1f-4f6b-8f9e-9f1e1f2e3d4c 10001 2023-10-01 14:30:00 \n", + "1 c2b1d4a6-7f8e-4c5c-9d8f-1e2c3b4a5e6f 10002 2023-10-01 15:00:00 \n", + "2 e3f2c5b7-8a2d-4c8e-9f1b-2c3d4e5f6a7b 10003 2023-10-01 15:30:00 \n", + "3 f4a5b6c8-9d3e-4b1f-9f2c-3d4e5f6a7b8c 10004 2023-10-01 16:00:00 \n", + "4 g5b6c7d9-0e4f-4b2a-8f3d-4e5f6a7b8c9d 10005 2023-10-01 16:30:00 \n", + "\n", + " country language device age gender education income_band \\\n", + "0 KE en android 28 female bachelor upper_mid \n", + "1 UG sw web 35 male diploma lower_mid \n", + "2 TZ en ios 42 nonbinary postgraduate high \n", + "3 RW sw web 29 female secondary upper_mid \n", + "4 NG en android 50 male bachelor high \n", + "\n", + " completion_seconds attention_passed q_quality q_value q_ease \\\n", + "0 450.0 True 5 4 5 \n", + "1 600.0 True 3 2 4 \n", + "2 720.0 True 4 5 4 \n", + "3 300.0 True 3 3 3 \n", + "4 540.0 True 5 5 5 \n", + "\n", + " q_support nps is_detractor \n", + "0 4 9 False \n", + "1 3 5 False \n", + "2 5 10 False \n", + "3 4 6 False \n", + "4 5 10 False \n", + "๐Ÿ’พ Saved: data\\survey_adaptive_batch_20251023T005927Z.csv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Joshua\\AppData\\Local\\Temp\\ipykernel_27572\\1770033334.py:22: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n" + ] + } + ], + "source": [ + "# Test the improved LLM generation with adaptive batching\n", + "print(\"๐Ÿงช Testing improved LLM generation with adaptive batching...\")\n", + "\n", + "# Test with smaller dataset first\n", + "print(\"\\n๐Ÿ“ฆ Testing small batch (10 rows)...\")\n", + "small_df = fixed_llm_generate_batch(CFG, 10)\n", + "print(f\"โœ… Small batch result: {len(small_df)} rows\")\n", + "\n", + "# Test with medium dataset using adaptive batching\n", + "print(\"\\n๐Ÿ“ฆ Testing medium dataset (30 rows) with adaptive batching...\")\n", + "medium_df = fixed_generate_llm(CFG, total_rows=30, batch_size=15)\n", + "print(f\"โœ… Medium dataset result: {len(medium_df)} rows\")\n", + "\n", + "if not medium_df.empty:\n", + " print(f\"\\n๐Ÿ“Š Dataset shape: {medium_df.shape}\")\n", + " print(f\"\\n๐Ÿ“‹ First few rows:\")\n", + " print(medium_df.head())\n", + " \n", + " # Save the results\n", + " from pathlib import Path\n", + " out = Path(\"data\"); out.mkdir(exist_ok=True)\n", + " ts = datetime.utcnow().strftime(\"%Y%m%dT%H%M%SZ\")\n", + " csv_path = out / f\"survey_adaptive_batch_{ts}.csv\"\n", + " medium_df.to_csv(csv_path, index=False)\n", + " print(f\"๐Ÿ’พ Saved: {csv_path}\")\n", + "else:\n", + " print(\"โŒ Medium dataset generation failed\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}