Add Juan contribution

This commit is contained in:
Jsrodrigue
2025-10-23 15:29:54 +01:00
parent a1a9bc0f95
commit 101b0baf62
18 changed files with 1426 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
import os
import glob

def cleanup_temp_files(temp_dir: str):
    """
    Remove all temporary files from the given directory.
    """
    files = glob.glob(os.path.join(temp_dir, "*"))
    for f in files:
        try:
            os.remove(f)
        except Exception as e:
            print(f"[Warning] Could not delete {f}: {e}")

View File

@@ -0,0 +1,45 @@
# ------------------- Setup Constants -------------------
N_REFERENCE_ROWS = 64 # Max reference rows per batch for sampling
MAX_TOKENS_MODEL = 128_000 # Max tokens supported by the model, used for batching computations
PROJECT_TEMP_DIR = "temp_plots"
# ------------------- Prompts -------------------
SYSTEM_PROMPT = """
You are a precise synthetic data generator. Your only task is to output valid JSON arrays of dictionaries.
Rules:
1. Output a single JSON array starting with '[' and ending with ']'.
2. Do not include markdown, code fences, or explanatory text — only the JSON.
3. Keep all columns exactly as specified; do not add or remove fields (index must be omitted).
4. Respect data types: text, number, date, boolean, etc.
5. Ensure internal consistency and realistic variation.
6. If a reference table is provided, generate data with similar statistical distributions for numerical and categorical variables,
but never copy exact rows. Each row must be independent and new.
7. For personal information (names, ages, addresses, IDs), ensure diversity and realism; individual values may be reused,
but never reuse or slightly modify entire reference rows.
8. Escape internal double quotes in strings with a backslash (\\") for JSON validity.
9. Do NOT replace single quotes in normal text; they should remain as-is.
10. Escape newline, tab, and carriage return characters as \\n, \\t, and \\r inside strings.
11. Remove any trailing commas before closing brackets.
12. Do not include any reference data or notes about it in the output.
13. The output must always be valid JSON parseable by standard JSON parsers.
14. Do not repeat any exact row, either from the reference data or from previously generated data.
15. When using reference data, consider the entire dataset for statistical patterns and diversity;
do not restrict generation to the first rows or the order of the dataset.
16. Introduce slight random variations in numerical values, and choose categorical values randomly according to the distribution,
without repeating rows.
"""
USER_PROMPT = """
Generate exactly 15 rows of synthetic data following all the rules above.
Ensure that all strings are safe for JSON parsing and ready to convert to a pandas DataFrame.
"""

View File

@@ -0,0 +1,108 @@
import os
from typing import List

import pandas as pd
from PIL import Image

from src.constants import MAX_TOKENS_MODEL, N_REFERENCE_ROWS
from src.evaluator import SimpleEvaluator
from src.helpers import hash_row, sample_reference
from src.openai_utils import detect_total_rows_from_prompt, generate_batch

# ------------------- Main Function -------------------
def generate_and_evaluate_data(
    system_prompt: str,
    user_prompt: str,
    temp_dir: str,
    reference_file=None,
    openai_model: str = "gpt-4o-mini",
    max_tokens_model: int = MAX_TOKENS_MODEL,
    n_reference_rows: int = N_REFERENCE_ROWS,
):
    """
    Generate synthetic data in batches, evaluate against reference data, and save results.
    Uses dynamic batching and reference sampling to optimize cost and token usage.
    """
    os.makedirs(temp_dir, exist_ok=True)
    reference_df = pd.read_csv(reference_file) if reference_file else None
    total_rows = detect_total_rows_from_prompt(user_prompt, openai_model)
    final_df = pd.DataFrame()
    existing_hashes = set()
    rows_left = total_rows
    iteration = 0
    print(f"[Info] Total rows requested: {total_rows}")

    # Estimate prompt tokens once per batch from the system prompt, user prompt,
    # and one reference sample (rough heuristic: ~4 characters per token).
    prompt_sample = f"{system_prompt} {user_prompt} {sample_reference(reference_df, n_reference_rows)}"
    prompt_tokens = max(1, len(prompt_sample) // 4)

    # Estimate tokens per row dynamically using a sample
    example_sample = sample_reference(reference_df, n_reference_rows)
    if example_sample is not None and len(example_sample) > 0:
        sample_text = str(example_sample)
        tokens_per_row = max(1, len(sample_text) // len(example_sample) // 4)
    else:
        tokens_per_row = 30  # fallback if no reference
    print(f"[Info] Tokens per row estimate: {tokens_per_row}, Prompt tokens: {prompt_tokens}")

    # ---------------- Batch Generation Loop ----------------
    while rows_left > 0:
        iteration += 1
        batch_sample = sample_reference(reference_df, n_reference_rows)
        batch_size = min(rows_left, max(1, (max_tokens_model - prompt_tokens) // tokens_per_row))
        print(f"[Batch {iteration}] Batch size: {batch_size}, Rows left: {rows_left}")
        try:
            df_batch = generate_batch(
                system_prompt, user_prompt, batch_sample, batch_size, openai_model
            )
        except Exception as e:
            print(f"[Error] Batch {iteration} failed: {e}")
            break

        # Filter duplicates using hash
        new_rows = [
            row
            for _, row in df_batch.iterrows()
            if hash_row(row) not in existing_hashes
        ]
        for row in new_rows:
            existing_hashes.add(hash_row(row))
        final_df = pd.concat([final_df, pd.DataFrame(new_rows)], ignore_index=True)
        rows_left = total_rows - len(final_df)
        print(
            f"[Batch {iteration}] Unique new rows added: {len(new_rows)}, Total so far: {len(final_df)}"
        )
        if len(new_rows) == 0:
            print("[Warning] No new unique rows. Stopping batches.")
            break

    # ---------------- Evaluation ----------------
    report_df, vis_dict = pd.DataFrame(), {}
    if reference_df is not None and not final_df.empty:
        evaluator = SimpleEvaluator(temp_dir=temp_dir)
        evaluator.evaluate(reference_df, final_df)
        report_df = evaluator.results_as_dataframe()
        vis_dict = evaluator.create_visualizations_temp_dict(reference_df, final_df)
        print(f"[Info] Evaluation complete. Report shape: {report_df.shape}")

    # ---------------- Collect Images ----------------
    all_images: List[Image.Image] = []
    for imgs in vis_dict.values():
        if isinstance(imgs, list):
            all_images.extend([img for img in imgs if img is not None])

    # ---------------- Save CSV ----------------
    final_csv_path = os.path.join(temp_dir, "synthetic_data.csv")
    final_df.to_csv(final_csv_path, index=False)
    print(f"[Done] Generated {len(final_df)} rows → saved to {final_csv_path}")
    generated_state = {}
    return final_df, final_csv_path, report_df, generated_state, all_images
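
A sketch of how this entry point might be invoked (the reference path is a placeholder; SYSTEM_PROMPT, USER_PROMPT, and PROJECT_TEMP_DIR come from src.constants):

from src.constants import PROJECT_TEMP_DIR, SYSTEM_PROMPT, USER_PROMPT

# Hypothetical invocation; "reference.csv" is a placeholder path.
df, csv_path, report, state, images = generate_and_evaluate_data(
    system_prompt=SYSTEM_PROMPT,
    user_prompt=USER_PROMPT,
    temp_dir=PROJECT_TEMP_DIR,
    reference_file="reference.csv",
)
print(csv_path, report.shape, len(images))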

View File

@@ -0,0 +1,142 @@
import os
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from PIL import Image

class SimpleEvaluator:
    """
    Evaluates synthetic data against a reference dataset, providing summary statistics and visualizations.
    """

    def __init__(self, temp_dir: str = "temp_plots"):
        """
        Initialize the evaluator.

        Args:
            temp_dir (str): Directory to save temporary plot images.
        """
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)

    def evaluate(self, reference_df: pd.DataFrame, generated_df: pd.DataFrame) -> Dict[str, Any]:
        """
        Compare numerical and categorical columns between reference and generated datasets.
        """
        self.results: Dict[str, Any] = {}
        self.common_cols = list(set(reference_df.columns) & set(generated_df.columns))
        for col in self.common_cols:
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                self.results[col] = {
                    "type": "numerical",
                    "ref_mean": reference_df[col].mean(),
                    "gen_mean": generated_df[col].mean(),
                    "mean_diff": generated_df[col].mean() - reference_df[col].mean(),
                    "ref_std": reference_df[col].std(),
                    "gen_std": generated_df[col].std(),
                    "std_diff": generated_df[col].std() - reference_df[col].std(),
                }
            else:
                ref_counts = reference_df[col].value_counts(normalize=True)
                gen_counts = generated_df[col].value_counts(normalize=True)
                overlap = sum(min(ref_counts.get(k, 0), gen_counts.get(k, 0)) for k in ref_counts.index)
                self.results[col] = {
                    "type": "categorical",
                    "distribution_overlap_pct": round(overlap * 100, 2),
                    "ref_unique": len(ref_counts),
                    "gen_unique": len(gen_counts),
                }
        return self.results

    def results_as_dataframe(self) -> pd.DataFrame:
        """
        Convert the evaluation results into a pandas DataFrame for display.
        """
        rows = []
        for col, stats in self.results.items():
            if stats["type"] == "numerical":
                rows.append({
                    "Column": col,
                    "Type": "Numerical",
                    "Ref Mean/Std": f"{stats['ref_mean']:.2f} / {stats['ref_std']:.2f}",
                    "Gen Mean/Std": f"{stats['gen_mean']:.2f} / {stats['gen_std']:.2f}",
                    "Diff": f"Mean diff: {stats['mean_diff']:.2f}, Std diff: {stats['std_diff']:.2f}",
                })
            else:
                rows.append({
                    "Column": col,
                    "Type": "Categorical",
                    "Ref": f"{stats['ref_unique']} unique",
                    "Gen": f"{stats['gen_unique']} unique",
                    "Diff": f"Overlap: {stats['distribution_overlap_pct']}%",
                })
        return pd.DataFrame(rows)

    def create_visualizations_temp_dict(
        self,
        reference_df: pd.DataFrame,
        generated_df: pd.DataFrame,
        percentage: bool = True,
    ) -> Dict[str, List[Optional[Image.Image]]]:
        """
        Create histogram and boxplot visualizations for each column and save them as temporary images.
        Handles special characters in column names and category labels.
        """
        vis_dict: Dict[str, List[Optional[Image.Image]]] = {}
        common_cols = list(set(reference_df.columns) & set(generated_df.columns))
        for col in common_cols:
            col_safe = str(col).replace("_", r"\_").replace("$", r"\$")  # Escape special chars

            # ---------------- Histogram ----------------
            plt.figure(figsize=(6, 4))
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                sns.histplot(reference_df[col], color="blue", label="Reference",
                             stat="percent" if percentage else "count", alpha=0.5)
                sns.histplot(generated_df[col], color="orange", label="Generated",
                             stat="percent" if percentage else "count", alpha=0.5)
            else:  # Categorical
                ref_counts = reference_df[col].value_counts(normalize=percentage)
                gen_counts = generated_df[col].value_counts(normalize=percentage)
                categories = list(set(ref_counts.index) | set(gen_counts.index))
                categories_safe = [str(cat).replace("_", r"\_").replace("$", r"\$") for cat in categories]
                ref_vals = [ref_counts.get(cat, 0) for cat in categories]
                gen_vals = [gen_counts.get(cat, 0) for cat in categories]
                x = range(len(categories))
                width = 0.4
                plt.bar([i - width / 2 for i in x], ref_vals, width=width, color="blue", alpha=0.7, label="Reference")
                plt.bar([i + width / 2 for i in x], gen_vals, width=width, color="orange", alpha=0.7, label="Generated")
                plt.xticks(x, categories_safe, rotation=45, ha="right")
            plt.title(f"Histogram comparison for '{col_safe}'", fontsize=12, usetex=False)
            plt.legend()
            plt.tight_layout()
            hist_path = os.path.join(self.temp_dir, f"{col}_hist.png")
            plt.savefig(hist_path, bbox_inches="tight")
            plt.close()
            hist_img = Image.open(hist_path)

            # ---------------- Boxplot (numerical only) ----------------
            box_img = None
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                plt.figure(figsize=(6, 4))
                df_box = pd.DataFrame({
                    "Value": pd.concat([reference_df[col], generated_df[col]], ignore_index=True),
                    "Dataset": ["Reference"] * len(reference_df[col]) + ["Generated"] * len(generated_df[col]),
                })
                sns.boxplot(x="Dataset", y="Value", data=df_box, palette=["#1f77b4", "#ff7f0e"])
                plt.title(f"Boxplot comparison for '{col_safe}'", fontsize=12, usetex=False)
                plt.tight_layout()
                box_path = os.path.join(self.temp_dir, f"{col}_box.png")
                plt.savefig(box_path, bbox_inches="tight")
                plt.close()
                box_img = Image.open(box_path)

            vis_dict[col] = [hist_img, box_img]
        return vis_dict
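
A small self-contained check of the evaluator on toy data (values are illustrative, not from the repo):

import pandas as pd

ref = pd.DataFrame({"age": [30, 40, 50], "city": ["A", "B", "A"]})
gen = pd.DataFrame({"age": [31, 39, 52], "city": ["A", "A", "B"]})
ev = SimpleEvaluator(temp_dir="temp_plots")
ev.evaluate(ref, gen)              # numerical stats for age, distribution overlap for city
print(ev.results_as_dataframe())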

View File

@@ -0,0 +1,14 @@
import hashlib

import pandas as pd

def hash_row(row: pd.Series) -> str:
    """Compute MD5 hash for a row to detect duplicates."""
    return hashlib.md5(str(tuple(row)).encode()).hexdigest()

def sample_reference(reference_df: pd.DataFrame, n_reference_rows: int) -> list:
    """Return a fresh sample of reference data for batch generation."""
    if reference_df is not None and not reference_df.empty:
        sample_df = reference_df.sample(min(n_reference_rows, len(reference_df)), replace=False)
        return sample_df.to_dict(orient="records")
    return []
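
One illustration of the dedup semantics (toy values): the hash is computed over the stringified value tuple, so it is sensitive to column order and to value representation.

import pandas as pd

r1 = pd.Series({"name": "Ana", "age": 34})
r2 = pd.Series({"name": "Ana", "age": 34.0})
assert hash_row(r1) != hash_row(r2)  # 34 and 34.0 stringify differently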

View File

@@ -0,0 +1,112 @@
import json
import os
import re
import tempfile
from typing import List

import openai
import pandas as pd

# ------------------ JSON Cleaning ------------------
def _clean_json_output(raw_text: str) -> str:
    """
    Cleans raw OpenAI output to produce valid JSON.
    Escapes only double quotes and control characters.
    """
    text = raw_text.strip()
    text = re.sub(r"```(?:json)?", "", text)
    text = re.sub(r"</?[^>]+>", "", text)

    def escape_quotes(match):
        value = match.group(1)
        value = value.replace('"', r"\"")
        value = value.replace("\n", r"\n").replace("\r", r"\r").replace("\t", r"\t")
        return f'"{value}"'

    text = re.sub(r'"(.*?)"', escape_quotes, text)
    if not text.startswith("["):
        text = "[" + text
    if not text.endswith("]"):
        text += "]"
    text = re.sub(r",\s*]", "]", text)
    return text

# ------------------ Synthetic Data Generation ------------------
def generate_synthetic_data_openai(
    system_prompt: str,
    full_user_prompt: str,
    openai_model: str = "gpt-4o-mini",
    max_tokens: int = 16000,
    temperature: float = 0.0,
):
    """
    Generates synthetic tabular data using OpenAI.
    Assumes `full_user_prompt` is already complete with reference data.
    """
    response = openai.chat.completions.create(
        model=openai_model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": full_user_prompt},
        ],
        max_completion_tokens=max_tokens,
        temperature=temperature,
    )
    raw_text = response.choices[0].message.content
    cleaned_json = _clean_json_output(raw_text)
    try:
        data = json.loads(cleaned_json)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid JSON generated. Error: {e}\nTruncated output: {cleaned_json[:500]}"
        )
    df = pd.DataFrame(data)
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
    df.to_csv(tmp_file.name, index=False)
    tmp_file.close()
    return df, tmp_file.name

# ---------------- Mini call to detect the number of rows in the prompt ----------------
def detect_total_rows_from_prompt(user_prompt: str, openai_model: str = "gpt-4o-mini") -> int:
    """
    Detect the number of rows requested from the user prompt.
    Fallback to 20 if detection fails.
    """
    mini_prompt = f"""
    Extract the number of rows to generate from this instruction:
    \"\"\"{user_prompt}\"\"\" Return only the number.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    try:
        response = openai.chat.completions.create(
            model=openai_model,
            messages=[{"role": "user", "content": mini_prompt}],
            temperature=0,
            max_tokens=10,
        )
        text = response.choices[0].message.content.strip()
        total_rows = int("".join(filter(str.isdigit, text)))
        return max(total_rows, 1)
    except Exception:
        return 20

# -------------- Function to generate synthetic data in a batch ---------------------
def generate_batch(system_prompt: str, user_prompt: str, reference_sample: List[dict],
                   batch_size: int, openai_model: str):
    """Generate a single batch of synthetic data using OpenAI."""
    full_prompt = f"{user_prompt}\nSample: {reference_sample}\nGenerate exactly {batch_size} rows."
    df_batch, _ = generate_synthetic_data_openai(
        system_prompt=system_prompt,
        full_user_prompt=full_prompt,
        openai_model=openai_model,
    )
    return df_batch
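
A quick sanity check of the cleaner on a typical fenced reply (toy, single-line input; the regex-based escaping does not cross real newlines):

raw = '```json[{"note": "a\tb"},]```'
cleaned = _clean_json_output(raw)   # fences stripped, tab escaped, trailing comma removed
print(json.loads(cleaned))          # [{'note': 'a\tb'}]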

View File

@@ -0,0 +1,13 @@
import pandas as pd

# -------------------------------
# Helper function to display CSV
# -------------------------------
def display_reference_csv(file):
    if file is None:
        return pd.DataFrame()
    try:
        df = pd.read_csv(file.name if hasattr(file, "name") else file)
        return df
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
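
The .name attribute handling suggests this helper backs a file-upload widget; a minimal Gradio wiring sketch under that assumption (not confirmed by this commit):

import gradio as gr

# Assumption: this helper is the callback for a Gradio file input.
demo = gr.Interface(
    fn=display_reference_csv,
    inputs=gr.File(label="Reference CSV"),
    outputs=gr.Dataframe(),
)
# demo.launch()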