{
"cells": [
{
"cell_type": "markdown",
"id": "63356928",
"metadata": {},
"source": [
"# Initial Note\n",
"After running experiments in Colab with open-source models from Hugging Face, I decided to do this exercise with OpenAI instead. Llama 3.2 frequently failed to follow the prompts, which led to inconsistent output and poor performance, and larger models increased processing time enough to make them impractical for this task.\n",
"\n",
"The code in this notebook will be reorganized into modules for the final demo."
]
},
{
"cell_type": "markdown",
"id": "5c12f081",
"metadata": {},
"source": [
"# Module to generate synthetic data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2389d798",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def _clean_json_output(raw_text: str) -> str:\n",
"    \"\"\"\n",
"    Clean the OpenAI output so that it parses as valid JSON:\n",
"    - Leaves the quotes around keys untouched.\n",
"    - Escapes only the double quotes inside value strings.\n",
"    - Escapes \\n, \\r, \\t.\n",
"    - Removes code fences and HTML.\n",
"    - Ensures the array starts with [ and ends with ].\n",
"    - Removes trailing commas.\n",
"    \"\"\"\n",
"    text = raw_text.strip()\n",
"\n",
"    # Remove code fences and HTML tags\n",
"    text = re.sub(r\"```(?:json)?\", \"\", text)\n",
"    text = re.sub(r\"</?[^>]+>\", \"\", text)\n",
"    # Strip again: removing a fence leaves a newline that would defeat the [ ] checks below\n",
"    text = text.strip()\n",
"\n",
"    # Escape double quotes and control characters inside each quoted string\n",
"    def escape_quotes_in_values(match):\n",
"        value = match.group(1)\n",
"        value = value.replace('\"', r'\\\"')  # only inside the value\n",
"        value = value.replace('\\n', r'\\n').replace('\\r', r'\\r').replace('\\t', r'\\t')\n",
"        return f'\"{value}\"'\n",
"\n",
"    text = re.sub(r'\"(.*?)\"', escape_quotes_in_values, text)\n",
"\n",
"    # Make sure the text starts with [ and ends with ]\n",
"    if not text.startswith('['):\n",
"        text = '[' + text\n",
"    if not text.endswith(']'):\n",
"        text += ']'\n",
"\n",
"    # Remove trailing commas before the closing bracket\n",
"    text = re.sub(r',\\s*]', ']', text)\n",
"\n",
"    return text\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75bfad6f",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import json\n",
"import openai\n",
"import tempfile\n",
"\n",
"\n",
"def generate_synthetic_data_openai(\n",
"    system_prompt: str,\n",
"    user_prompt: str,\n",
"    reference_file=None,\n",
"    openai_model=\"gpt-4o-mini\",\n",
"    max_tokens=2048,\n",
"    temperature=0.0\n",
"):\n",
"    \"\"\"\n",
"    Generate synthetic data and return the DataFrame and the path to a temporary CSV.\n",
"    \"\"\"\n",
"    # Build the full prompt\n",
"    if reference_file:\n",
"        # pd.read_csv accepts both a path string and a file-like object\n",
"        df_ref = pd.read_csv(reference_file)\n",
"        reference_data = df_ref.to_dict(orient=\"records\")\n",
"        user_prompt_full = (\n",
"            f\"{user_prompt}\\nFollow the structure and distribution of the reference data, \"\n",
"            f\"but do NOT copy any exact values:\\n{reference_data}\"\n",
"        )\n",
"    else:\n",
"        user_prompt_full = user_prompt\n",
"\n",
"    # Call OpenAI\n",
"    response = openai.chat.completions.create(\n",
"        model=openai_model,\n",
"        messages=[\n",
"            {\"role\": \"system\", \"content\": system_prompt},\n",
"            {\"role\": \"user\", \"content\": user_prompt_full},\n",
"        ],\n",
"        temperature=temperature,\n",
"        max_tokens=max_tokens,\n",
"    )\n",
"\n",
"    raw_text = response.choices[0].message.content\n",
"    cleaned_json = _clean_json_output(raw_text)\n",
"\n",
"    # Parse the JSON\n",
"    try:\n",
"        data = json.loads(cleaned_json)\n",
"    except json.JSONDecodeError as e:\n",
"        raise ValueError(f\"Invalid JSON generated. Error: {e}\\nTruncated output: {cleaned_json[:500]}\") from e\n",
"\n",
"    df = pd.DataFrame(data)\n",
"\n",
"    # Save a temporary CSV\n",
"    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=\".csv\")\n",
"    tmp_file.close()  # close the handle first so the path can be reopened for writing\n",
"    df.to_csv(tmp_file.name, index=False)\n",
"\n",
"    return df, tmp_file.name\n"
]
},
{
"cell_type": "markdown",
"id": "91af1eb5",
"metadata": {},
"source": [
"# Default prompts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "792d1555",
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_PROMPT = \"\"\"\n",
"You are a precise synthetic data generator. Your only task is to output valid JSON arrays of dictionaries.\n",
"\n",
"Rules:\n",
"1. Output a single JSON array starting with '[' and ending with ']'.\n",
"2. Do not include markdown, code fences, or explanatory text; output only the JSON.\n",
"3. Keep all columns exactly as specified; do not add or remove fields (index must be omitted).\n",
"4. Respect data types: text, number, date, boolean, etc.\n",
"5. Ensure internal consistency and realistic variation.\n",
"6. If a reference table is provided, generate data with similar statistical distributions for numerical and categorical variables,\n",
"   but never copy exact rows. Each row must be independent and new.\n",
"7. For personal information (names, ages, addresses, IDs), ensure diversity and realism. Individual values may be reused to maintain realism,\n",
"   but never reuse or slightly modify entire reference rows.\n",
"8. Escape all internal double quotes in strings with a backslash (\\\\\").\n",
"9. Replace any single quotes in strings with double quotes.\n",
"10. Escape newline, tab, and carriage return characters as \\\\n, \\\\t, and \\\\r inside strings.\n",
"11. Remove any trailing commas before closing brackets.\n",
"12. Do not include any reference data or notes about it in the output.\n",
"13. The output must always be valid JSON parseable by standard JSON parsers.\n",
"\"\"\"\n",
"\n",
"USER_PROMPT = \"\"\"\n",
"Generate exactly 15 rows of synthetic data following all the rules above.\n",
"Ensure that all strings are safe for JSON parsing and ready to convert to a pandas DataFrame.\n",
"\"\"\"\n"
]
},
{
"cell_type": "markdown",
"id": "6f9331fa",
"metadata": {},
"source": [
"# Test"
]
},
{
"cell_type": "markdown",
"id": "d38f0afb",
"metadata": {},
"source": [
"To test the generator, we use the first 50 examples from the Reddit gaming comments with sentiments dataset.\n",
"\n",
"Source: https://www.kaggle.com/datasets/sainitishmitta04/23k-reddit-gaming-comments-with-sentiments-dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78d94faa",
"metadata": {},
"outputs": [],
"source": [
"df, _ = generate_synthetic_data_openai(SYSTEM_PROMPT, USER_PROMPT, reference_file=\"data/sentiment_reference.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e6b5ebb",
"metadata": {},
"outputs": [],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "015a3110",
"metadata": {},
"outputs": [],
"source": [
"print(df.Comment[0])"
]
},
{
"cell_type": "markdown",
"id": "0ef44876",
"metadata": {},
"source": [
"# Gradio Demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa4092f4",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"# 🧠 Synthetic Data Generator\")\n",
"\n",
"    with gr.Row():\n",
"        system_prompt_input = gr.Textbox(label=\"System Prompt\", value=SYSTEM_PROMPT, lines=10)\n",
"\n",
"    with gr.Row():\n",
"        user_prompt_input = gr.Textbox(label=\"User Prompt\", value=USER_PROMPT, lines=5)\n",
"\n",
"    with gr.Row():\n",
"        reference_input = gr.File(label=\"Reference CSV (optional)\", file_types=[\".csv\"])\n",
"\n",
"    output_df = gr.DataFrame(label=\"Generated Data\")\n",
"    download_csv = gr.File(label=\"Download CSV\")\n",
"\n",
"    generate_btn = gr.Button(\"🚀 Generate Data\")\n",
"\n",
"    generate_btn.click(\n",
"        fn=generate_synthetic_data_openai,\n",
"        inputs=[system_prompt_input, user_prompt_input, reference_input],\n",
"        outputs=[output_df, download_csv]\n",
"    )\n",
"\n",
"demo.launch(debug=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}