Add juan contribution
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
import glob
|
||||
|
||||
def cleanup_temp_files(temp_dir: str):
    """
    Remove all temporary files from the given directory.

    Best-effort cleanup: subdirectories are skipped (``os.remove`` cannot
    delete them), and any file that cannot be removed is reported with a
    warning instead of raising.

    Args:
        temp_dir (str): Directory whose files should be deleted.
    """
    for f in glob.glob(os.path.join(temp_dir, "*")):
        # glob("*") also matches subdirectories; os.remove would fail on
        # them with IsADirectoryError, so skip anything that is not a file.
        if not os.path.isfile(f):
            continue
        try:
            os.remove(f)
        except OSError as e:
            print(f"[Warning] Could not delete {f}: {e}")
|
||||
@@ -0,0 +1,45 @@
|
||||
# ------------------- Setup Constants -------------------
N_REFERENCE_ROWS = 64  # Max reference rows sampled into each batch prompt
MAX_TOKENS_MODEL = 128_000  # Max tokens supported by the model, used for batching computations
PROJECT_TEMP_DIR = "temp_plots"  # Directory where temporary plot images and outputs are written
|
||||
|
||||
|
||||
|
||||
# ----------------- Prompts -------------------------------
# NOTE: backslashes in rules 8 and 10 are doubled (\\n, \\t, \\r, \\") so the
# model receives the literal two-character escape sequences. Previously they
# were single, so real control characters leaked into the prompt text.
SYSTEM_PROMPT = """
You are a precise synthetic data generator. Your only task is to output valid JSON arrays of dictionaries.

Rules:
1. Output a single JSON array starting with '[' and ending with ']'.
2. Do not include markdown, code fences, or explanatory text — only the JSON.
3. Keep all columns exactly as specified; do not add or remove fields (index must be omitted).
4. Respect data types: text, number, date, boolean, etc.
5. Ensure internal consistency and realistic variation.
6. If a reference table is provided, generate data with similar statistical distributions for numerical and categorical variables,
but never copy exact rows. Each row must be independent and new.
7. For personal information (names, ages, addresses, IDs), ensure diversity and realism — individual values may be reused to maintain realism,
but never reuse or slightly modify entire reference rows.
8. Escape internal double quotes in strings with a backslash (\\") for JSON validity.
9. Do NOT replace single quotes in normal text; they should remain as-is.
10. Escape newline (\\n), tab (\\t), or carriage return (\\r) characters as \\n, \\t, \\r inside strings.
11. Remove any trailing commas before closing brackets.
12. Do not include any reference data or notes about it in the output.
13. The output must always be valid JSON parseable by standard JSON parsers.
14. Don't repeat any exact column neither from the reference or from previous generated data.
15. When using reference data, consider the entire dataset for statistical patterns and diversity;
do not restrict generation to the first rows or the order of the dataset.
16. Introduce slight random variations in numerical values, and choose categorical values randomly according to the distribution,
without repeating rows.
"""

USER_PROMPT = """
Generate exactly 15 rows of synthetic data following all the rules above.
Ensure that all strings are safe for JSON parsing and ready to convert to a pandas DataFrame.
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
from src.constants import MAX_TOKENS_MODEL, N_REFERENCE_ROWS
|
||||
from src.evaluator import SimpleEvaluator
|
||||
from src.helpers import hash_row, sample_reference
|
||||
from src.openai_utils import detect_total_rows_from_prompt, generate_batch
|
||||
|
||||
|
||||
# ------------------- Main Function -------------------
def generate_and_evaluate_data(
    system_prompt: str,
    user_prompt: str,
    temp_dir: str,
    reference_file=None,
    openai_model: str = "gpt-4o-mini",
    max_tokens_model: int = MAX_TOKENS_MODEL,
    n_reference_rows: int = N_REFERENCE_ROWS,
):
    """
    Generate synthetic data in batches, evaluate against reference data, and save results.

    Uses dynamic batching and reference sampling to optimize cost and token usage.

    Args:
        system_prompt: System message passed to every generation call.
        user_prompt: User instruction; the requested row count is extracted from it.
        temp_dir: Directory for the output CSV and evaluation plot images.
        reference_file: Optional CSV source accepted by ``pd.read_csv`` (path or
            file-like — presumably an upload widget object; confirm with caller).
            ``None`` disables reference sampling and evaluation.
        openai_model: OpenAI chat model name used for all calls.
        max_tokens_model: Token budget per request, used to size each batch.
        n_reference_rows: Max reference rows sampled into each batch prompt.

    Returns:
        Tuple of (final_df, final_csv_path, report_df, generated_state, all_images).
    """
    os.makedirs(temp_dir, exist_ok=True)
    reference_df = pd.read_csv(reference_file) if reference_file else None
    total_rows = detect_total_rows_from_prompt(user_prompt, openai_model)

    final_df = pd.DataFrame()
    existing_hashes = set()  # hashes of rows emitted so far, used to drop duplicates
    rows_left = total_rows
    iteration = 0

    print(f"[Info] Total rows requested: {total_rows}")

    # Estimate tokens for the prompt by adding system, user and sample (used once per batch)
    # NOTE: len(text) // 4 is a rough chars-per-token heuristic, not a real tokenizer.
    prompt_sample = f"{system_prompt} {user_prompt} {sample_reference(reference_df, n_reference_rows)}"
    prompt_tokens = max(1, len(prompt_sample) // 4)

    # Estimate tokens per row dynamically using a sample
    example_sample = sample_reference(reference_df, n_reference_rows)
    if example_sample is not None and len(example_sample) > 0:
        sample_text = str(example_sample)
        tokens_per_row = max(1, len(sample_text) // len(example_sample) // 4)
    else:
        tokens_per_row = 30  # fallback if no reference

    print(f"[Info] Tokens per row estimate: {tokens_per_row}, Prompt tokens: {prompt_tokens}")

    # ---------------- Batch Generation Loop ----------------
    while rows_left > 0:
        iteration += 1
        # Fresh random sample each batch so the model sees varied reference rows.
        batch_sample = sample_reference(reference_df, n_reference_rows)
        batch_size = min(rows_left, max(1, (max_tokens_model - prompt_tokens) // tokens_per_row))
        print(f"[Batch {iteration}] Batch size: {batch_size}, Rows left: {rows_left}")

        try:
            df_batch = generate_batch(
                system_prompt, user_prompt, batch_sample, batch_size, openai_model
            )
        except Exception as e:
            # Best-effort: keep whatever rows were generated in earlier batches.
            print(f"[Error] Batch {iteration} failed: {e}")
            break

        # Filter duplicates using hash
        new_rows = [
            row
            for _, row in df_batch.iterrows()
            if hash_row(row) not in existing_hashes
        ]
        for row in new_rows:
            existing_hashes.add(hash_row(row))

        final_df = pd.concat([final_df, pd.DataFrame(new_rows)], ignore_index=True)
        rows_left = total_rows - len(final_df)
        print(
            f"[Batch {iteration}] Unique new rows added: {len(new_rows)}, Total so far: {len(final_df)}"
        )

        if len(new_rows) == 0:
            # The model is only repeating itself; stop to avoid an infinite loop.
            print("[Warning] No new unique rows. Stopping batches.")
            break

    # ---------------- Evaluation ----------------
    report_df, vis_dict = pd.DataFrame(), {}
    if reference_df is not None and not final_df.empty:
        evaluator = SimpleEvaluator(temp_dir=temp_dir)
        evaluator.evaluate(reference_df, final_df)
        report_df = evaluator.results_as_dataframe()
        vis_dict = evaluator.create_visualizations_temp_dict(reference_df, final_df)
        print(f"[Info] Evaluation complete. Report shape: {report_df.shape}")

    # ---------------- Collect Images ----------------
    # Flatten the per-column [hist, box] lists, dropping None boxplots.
    all_images: List[Image.Image] = []
    for imgs in vis_dict.values():
        if isinstance(imgs, list):
            all_images.extend([img for img in imgs if img is not None])

    # ---------------- Save CSV ----------------
    final_csv_path = os.path.join(temp_dir, "synthetic_data.csv")
    final_df.to_csv(final_csv_path, index=False)
    print(f"[Done] Generated {len(final_df)} rows → saved to {final_csv_path}")

    # Always empty here; presumably a placeholder consumed by the UI — confirm.
    generated_state = {}

    return final_df, final_csv_path, report_df, generated_state, all_images
|
||||
@@ -0,0 +1,142 @@
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
class SimpleEvaluator:
    """
    Evaluates synthetic data against a reference dataset, providing summary
    statistics and visualizations.

    Call ``evaluate`` first; ``results_as_dataframe`` and
    ``create_visualizations_temp_dict`` then report on that comparison.
    """

    def __init__(self, temp_dir: str = "temp_plots"):
        """
        Initialize the evaluator.

        Args:
            temp_dir (str): Directory to save temporary plot images.
        """
        self.temp_dir = temp_dir
        os.makedirs(self.temp_dir, exist_ok=True)
        # Initialized here so results_as_dataframe() is safe to call even
        # before evaluate() has run (previously it raised AttributeError).
        self.results: Dict[str, Any] = {}
        self.common_cols: List[str] = []

    def evaluate(self, reference_df: pd.DataFrame, generated_df: pd.DataFrame) -> Dict[str, Any]:
        """
        Compare numerical and categorical columns between reference and generated datasets.

        Only columns present in both frames are compared. Numeric columns get
        mean/std deltas; every other column is treated as categorical and gets
        a distribution-overlap percentage plus unique-value counts.

        Returns:
            Dict mapping column name to its per-column statistics.
        """
        self.results = {}
        self.common_cols = list(set(reference_df.columns) & set(generated_df.columns))

        for col in self.common_cols:
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                self.results[col] = {
                    "type": "numerical",
                    "ref_mean": reference_df[col].mean(),
                    "gen_mean": generated_df[col].mean(),
                    "mean_diff": generated_df[col].mean() - reference_df[col].mean(),
                    "ref_std": reference_df[col].std(),
                    "gen_std": generated_df[col].std(),
                    "std_diff": generated_df[col].std() - reference_df[col].std(),
                }
            else:
                ref_counts = reference_df[col].value_counts(normalize=True)
                gen_counts = generated_df[col].value_counts(normalize=True)
                # Overlap = shared probability mass of the two distributions (0..1).
                overlap = sum(min(ref_counts.get(k, 0), gen_counts.get(k, 0)) for k in ref_counts.index)
                self.results[col] = {
                    "type": "categorical",
                    "distribution_overlap_pct": round(overlap * 100, 2),
                    "ref_unique": len(ref_counts),
                    "gen_unique": len(gen_counts)
                }

        return self.results

    def results_as_dataframe(self) -> pd.DataFrame:
        """
        Convert the evaluation results into a pandas DataFrame for display.

        Returns an empty DataFrame when evaluate() has not been called yet.
        """
        rows = []
        for col, stats in self.results.items():
            if stats["type"] == "numerical":
                rows.append({
                    "Column": col,
                    "Type": "Numerical",
                    "Ref Mean/Std": f"{stats['ref_mean']:.2f} / {stats['ref_std']:.2f}",
                    "Gen Mean/Std": f"{stats['gen_mean']:.2f} / {stats['gen_std']:.2f}",
                    "Diff": f"Mean diff: {stats['mean_diff']:.2f}, Std diff: {stats['std_diff']:.2f}"
                })
            else:
                rows.append({
                    "Column": col,
                    "Type": "Categorical",
                    "Ref": f"{stats['ref_unique']} unique",
                    "Gen": f"{stats['gen_unique']} unique",
                    "Diff": f"Overlap: {stats['distribution_overlap_pct']}%"
                })
        return pd.DataFrame(rows)

    def create_visualizations_temp_dict(
        self,
        reference_df: pd.DataFrame,
        generated_df: pd.DataFrame,
        percentage: bool = True
    ) -> "Dict[str, List[Optional[Image.Image]]]":
        """
        Create histogram and boxplot visualizations for each common column and
        save them as temporary images under ``self.temp_dir``.

        Handles special characters in column names and category labels by
        escaping them for matplotlib titles/labels.

        Args:
            reference_df: Reference dataset.
            generated_df: Generated dataset to compare against the reference.
            percentage: If True, plot percentages instead of raw counts.

        Returns:
            Dict mapping column name to ``[histogram_image, boxplot_image]``;
            the boxplot entry is None for non-numeric columns.
        """
        vis_dict: Dict[str, List[Optional[Image.Image]]] = {}
        common_cols = list(set(reference_df.columns) & set(generated_df.columns))

        for col in common_cols:
            # Escape chars matplotlib may interpret as math-text markers.
            col_safe = str(col).replace("_", r"\_").replace("$", r"\$")

            # ---------------- Histogram ----------------
            plt.figure(figsize=(6, 4))
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                sns.histplot(reference_df[col], color="blue", label="Reference",
                             stat="percent" if percentage else "count", alpha=0.5)
                sns.histplot(generated_df[col], color="orange", label="Generated",
                             stat="percent" if percentage else "count", alpha=0.5)
            else:  # Categorical: grouped bar chart over the union of categories
                ref_counts = reference_df[col].value_counts(normalize=percentage)
                gen_counts = generated_df[col].value_counts(normalize=percentage)
                categories = list(set(ref_counts.index) | set(gen_counts.index))
                categories_safe = [str(cat).replace("_", r"\_").replace("$", r"\$") for cat in categories]
                ref_vals = [ref_counts.get(cat, 0) for cat in categories]
                gen_vals = [gen_counts.get(cat, 0) for cat in categories]

                x = range(len(categories))
                width = 0.4
                plt.bar([i - width/2 for i in x], ref_vals, width=width, color="blue", alpha=0.7, label="Reference")
                plt.bar([i + width/2 for i in x], gen_vals, width=width, color="orange", alpha=0.7, label="Generated")
                plt.xticks(x, categories_safe, rotation=45, ha="right")

            plt.title(f"Histogram comparison for '{col_safe}'", fontsize=12, usetex=False)
            plt.legend()
            plt.tight_layout()
            hist_path = os.path.join(self.temp_dir, f"{col}_hist.png")
            plt.savefig(hist_path, bbox_inches='tight')
            plt.close()
            hist_img = Image.open(hist_path)

            # ---------------- Boxplot (numerical only) ----------------
            box_img = None
            if pd.api.types.is_numeric_dtype(reference_df[col]):
                plt.figure(figsize=(6, 4))
                df_box = pd.DataFrame({
                    'Value': pd.concat([reference_df[col], generated_df[col]], ignore_index=True),
                    'Dataset': ['Reference']*len(reference_df[col]) + ['Generated']*len(generated_df[col])
                })

                sns.boxplot(x='Dataset', y='Value', data=df_box, palette=['#1f77b4', '#ff7f0e'])
                plt.title(f"Boxplot comparison for '{col_safe}'", fontsize=12, usetex=False)
                plt.tight_layout()
                box_path = os.path.join(self.temp_dir, f"{col}_box.png")
                plt.savefig(box_path, bbox_inches='tight')
                plt.close()
                box_img = Image.open(box_path)

            vis_dict[col] = [hist_img, box_img]

        return vis_dict
|
||||
@@ -0,0 +1,14 @@
|
||||
import hashlib
|
||||
import pandas as pd
|
||||
|
||||
def hash_row(row: pd.Series) -> str:
    """Return an MD5 digest of the row's values, used to detect duplicate rows."""
    fingerprint = str(tuple(row)).encode()
    return hashlib.md5(fingerprint).hexdigest()
|
||||
|
||||
|
||||
def sample_reference(reference_df: pd.DataFrame, n_reference_rows: int) -> list:
    """Return a fresh random sample of reference rows as a list of record dicts.

    Returns an empty list when no usable reference data is available.
    """
    if reference_df is None or reference_df.empty:
        return []
    sample_size = min(n_reference_rows, len(reference_df))
    drawn = reference_df.sample(sample_size, replace=False)
    return drawn.to_dict(orient="records")
|
||||
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
import re
|
||||
import tempfile
|
||||
import openai
|
||||
import pandas as pd
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
|
||||
# ------------------ JSON Cleaning ------------------
|
||||
def _clean_json_output(raw_text: str) -> str:
|
||||
"""
|
||||
Cleans raw OpenAI output to produce valid JSON.
|
||||
Escapes only double quotes and control characters.
|
||||
"""
|
||||
text = raw_text.strip()
|
||||
text = re.sub(r"```(?:json)?", "", text)
|
||||
text = re.sub(r"</?[^>]+>", "", text)
|
||||
|
||||
def escape_quotes(match):
|
||||
value = match.group(1)
|
||||
value = value.replace('"', r"\"")
|
||||
value = value.replace("\n", r"\n").replace("\r", r"\r").replace("\t", r"\t")
|
||||
return f'"{value}"'
|
||||
|
||||
text = re.sub(r'"(.*?)"', escape_quotes, text)
|
||||
|
||||
if not text.startswith("["):
|
||||
text = "[" + text
|
||||
if not text.endswith("]"):
|
||||
text += "]"
|
||||
text = re.sub(r",\s*]", "]", text)
|
||||
return text
|
||||
|
||||
|
||||
# ------------------ Synthetic Data Generation ------------------
def generate_synthetic_data_openai(
    system_prompt: str,
    full_user_prompt: str,
    openai_model: str = "gpt-4o-mini",
    max_tokens: int = 16000,
    temperature: float = 0.0,
):
    """
    Generates synthetic tabular data using OpenAI.

    Assumes `full_user_prompt` is already complete with reference data.

    Args:
        system_prompt: System message for the chat completion.
        full_user_prompt: Complete user message, including any reference sample.
        openai_model: Chat model name.
        max_tokens: Completion-token cap for the request.
        temperature: Sampling temperature (0.0 = deterministic-ish).

    Returns:
        Tuple of (DataFrame with the parsed rows, path of a temp CSV copy).

    Raises:
        ValueError: If the model returns no content or invalid JSON.
    """
    response = openai.chat.completions.create(
        model=openai_model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": full_user_prompt},
        ],
        max_completion_tokens=max_tokens,
        temperature=temperature,
    )

    raw_text = response.choices[0].message.content
    # content can be None (e.g. refusal / filtered response); fail loudly
    # instead of crashing later with AttributeError on .strip().
    if raw_text is None:
        raise ValueError("Model returned no content.")

    cleaned_json = _clean_json_output(raw_text)

    try:
        data = json.loads(cleaned_json)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid JSON generated. Error: {e}\nTruncated output: {cleaned_json[:500]}"
        ) from e

    df = pd.DataFrame(data)

    # delete=False: the caller needs the file to survive past this function.
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
    df.to_csv(tmp_file.name, index=False)
    tmp_file.close()

    return df, tmp_file.name
|
||||
|
||||
# ---------------------- Mini call to detect the number of rows in the prompt --------------
def detect_total_rows_from_prompt(user_prompt: str, openai_model: str = "gpt-4o-mini") -> int:
    """
    Detect the number of rows requested from the user prompt.
    Fallback to 20 if detection fails.
    """
    mini_prompt = f"""
    Extract the number of rows to generate from this instruction:
    \"\"\"{user_prompt}\"\"\" Return only the number.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    try:
        response = openai.chat.completions.create(
            model=openai_model,
            messages=[{"role": "user", "content": mini_prompt}],
            temperature=0,
            max_tokens=10,
        )
        reply = response.choices[0].message.content.strip()
        digits = "".join(ch for ch in reply if ch.isdigit())
        # int("") raises ValueError when no digits are found, which is caught
        # below and mapped to the 20-row fallback.
        return max(int(digits), 1)
    except Exception:
        # Any failure (API error, missing digits, parse error) -> default 20.
        return 20
|
||||
|
||||
|
||||
# -------------- Function to generate synthetic data in a batch ---------------------
def generate_batch(system_prompt: str, user_prompt: str, reference_sample: List[dict],
                   batch_size: int, openai_model: str):
    """Generate a single batch of synthetic data using OpenAI."""
    # Append the reference sample and the exact row count to the base prompt.
    full_prompt = (
        f"{user_prompt}\nSample: {reference_sample}\nGenerate exactly {batch_size} rows."
    )
    batch_df, _csv_path = generate_synthetic_data_openai(
        system_prompt=system_prompt,
        full_user_prompt=full_prompt,
        openai_model=openai_model,
    )
    return batch_df
|
||||
@@ -0,0 +1,13 @@
|
||||
import pandas as pd
|
||||
|
||||
# -------------------------------
# Helper function to display CSV
# -------------------------------
def display_reference_csv(file):
    """Load an uploaded CSV into a DataFrame for display.

    Accepts a path, a file-like object, or an upload object exposing `.name`.
    On read failure, returns a one-column DataFrame carrying the error text
    so the UI can show it instead of crashing.
    """
    if file is None:
        return pd.DataFrame()
    try:
        source = file.name if hasattr(file, "name") else file
        return pd.read_csv(source)
    except Exception as exc:
        return pd.DataFrame({"Error": [str(exc)]})
|
||||
Reference in New Issue
Block a user