169 lines
4.9 KiB
Python
169 lines
4.9 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
import base64
|
|
import subprocess
|
|
|
|
# Dependency Check
# Both third-party packages are imported up front so a missing install stops
# the script immediately with an install hint instead of a traceback later.
# NOTE(review): `Image` is not referenced anywhere else in the code visible
# here; presumably Pillow is imported only to verify the dependency that
# resize_assets.py needs before any API calls are made — confirm.
try:
    import requests
except ImportError:
    print("❌ Error: requests library not found.")
    print("👉 Please run: pip install requests pillow")
    sys.exit(1)

try:
    from PIL import Image
except ImportError:
    print("❌ Error: Pillow library not found.")
    print("👉 Please run: pip install requests pillow")
    sys.exit(1)


# ============================================================
# CONFIGURATION — Set your Gemini API key here
# ============================================================
# The GEMINI_API_KEY environment variable takes precedence, so the key does
# not have to be written into (and committed with) this file. If the variable
# is unset or empty, the value below is used; main() refuses to run while it
# is still the placeholder.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or "YOUR_GEMINI_API_KEY_HERE"
# ============================================================


# Paths
# This script lives one directory below the project root (BASE_DIR); the
# prompts file and resize script sit next to it, and generated images are
# written to <project root>/Raw_Assets.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(SCRIPT_DIR)
PROMPTS_FILE = os.path.join(SCRIPT_DIR, "image_generation_prompts.json")
RAW_ASSETS_DIR = os.path.join(BASE_DIR, "Raw_Assets")
RESIZE_SCRIPT = os.path.join(SCRIPT_DIR, "resize_assets.py")


# Gemini Imagen API endpoint
IMAGEN_MODEL = "imagen-3.0-generate-001"
# NOTE: the API key is carried in the query string, so this full URL is a
# secret — never print it or exception messages that embed it unredacted.
IMAGEN_URL = (
    f"https://generativelanguage.googleapis.com/v1beta/models/"
    f"{IMAGEN_MODEL}:predict?key={GEMINI_API_KEY}"
)


def load_prompts(path=None):
    """Load image prompts from the JSON config file.

    The file must contain a JSON object whose optional ``"images"`` key holds
    a list of prompt entries; a missing key yields an empty list.

    Args:
        path: JSON file to read. Defaults to ``PROMPTS_FILE``.

    Returns:
        The list stored under ``"images"`` (entries are returned unvalidated).

    Exits:
        With status 1, after printing an error, if the file does not exist,
        cannot be decoded as JSON, or does not have the expected shape.
    """
    prompts_path = PROMPTS_FILE if path is None else path

    try:
        with open(prompts_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"❌ Error: Prompts file not found: {prompts_path}")
        sys.exit(1)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
        # report them cleanly instead of dumping a traceback.
        print(f"❌ Error: Could not parse prompts file {prompts_path}: {e}")
        sys.exit(1)

    # Shape checks: main() iterates the result and calls .get() on entries,
    # so a top-level array or a non-list "images" value would crash later.
    if not isinstance(data, dict):
        print(f"❌ Error: Expected a JSON object at the top level of {prompts_path}")
        sys.exit(1)

    images = data.get("images", [])
    if not isinstance(images, list):
        print(f"❌ Error: 'images' in {prompts_path} must be a list")
        sys.exit(1)

    return images


def _redact_api_key(text):
    """Return ``str(text)`` with the configured API key masked out.

    IMAGEN_URL carries the key as a query parameter, and requests embeds the
    request URL in many exception messages (e.g. HTTPError's "... for url:"),
    so printing those messages verbatim would leak the key.
    """
    message = str(text)
    if GEMINI_API_KEY:
        message = message.replace(GEMINI_API_KEY, "***")
    return message


def generate_image(name, prompt):
    """Call the Gemini Imagen API and return the image bytes, or None on failure.

    Args:
        name: Asset name, used only to label console messages.
        prompt: Text prompt sent as the single request instance.

    Returns:
        The base64-decoded ``bytesBase64Encoded`` field of the first
        prediction, or ``None`` if the request fails, the body is not JSON,
        or no decodable image data is present. Failures are reported on
        stdout (with the API key redacted) rather than raised.
    """
    payload = {
        "instances": [{"prompt": prompt}],
        "parameters": {"sampleCount": 1},
    }

    try:
        response = requests.post(IMAGEN_URL, json=payload, timeout=120)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        print(f" ❌ HTTP error for '{name}': {_redact_api_key(e)}")
        print(f" Response: {_redact_api_key(response.text[:300])}")
        return None
    except requests.exceptions.RequestException as e:
        print(f" ❌ Request failed for '{name}': {_redact_api_key(e)}")
        return None

    # A 2xx response can still carry a non-JSON body; requests raises a
    # ValueError subclass for that, which would otherwise abort the whole run.
    try:
        result = response.json()
    except ValueError:
        print(f" ❌ Invalid JSON in response for '{name}'.")
        return None

    predictions = result.get("predictions") if isinstance(result, dict) else None
    if not isinstance(predictions, list) or not predictions:
        print(f" ❌ No predictions returned for '{name}'.")
        return None

    first = predictions[0]
    b64_data = first.get("bytesBase64Encoded") if isinstance(first, dict) else None
    if not b64_data:
        print(f" ❌ No image data in response for '{name}'.")
        return None

    try:
        return base64.b64decode(b64_data)
    except (ValueError, TypeError):
        # binascii.Error (e.g. bad padding) subclasses ValueError; TypeError
        # covers a non-string payload.
        print(f" ❌ Corrupt image data in response for '{name}'.")
        return None


def save_image(name, image_bytes, output_dir=None):
    """Save raw bytes as a PNG file in Raw_Assets/ (or *output_dir*).

    The bytes go to a sibling ``<name>.png.tmp`` file that is then moved into
    place with ``os.replace``, so a failed write never leaves a truncated
    ``<name>.png`` behind for later pipeline steps to pick up. An existing
    file with the same name is overwritten.

    Args:
        name: Asset name; the file is written as ``<name>.png``. The name is
            joined onto the directory as-is (not sanitized).
        image_bytes: Image data, written verbatim.
        output_dir: Destination directory, which must already exist.
            Defaults to ``RAW_ASSETS_DIR``.

    Returns:
        The full path of the written file.

    Raises:
        OSError: If the file cannot be written or moved into place.
    """
    target_dir = RAW_ASSETS_DIR if output_dir is None else output_dir
    output_path = os.path.join(target_dir, f"{name}.png")
    tmp_path = f"{output_path}.tmp"

    try:
        with open(tmp_path, "wb") as f:
            f.write(image_bytes)
        os.replace(tmp_path, output_path)
    except OSError:
        # Best-effort removal of the partial temp file; the original error is
        # the one the caller needs, so it is re-raised unchanged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    return output_path


def run_resize_script():
    """Run resize_assets.py under the current interpreter and echo its output.

    Output is captured and printed once the script finishes. A non-zero exit
    status is reported on stdout (with whatever the script printed) instead
    of being raised to the caller.
    """
    print("\n🔧 Running resize_assets.py...")
    command = [sys.executable, RESIZE_SCRIPT]

    try:
        completed = subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        print(f"❌ resize_assets.py failed: {err}")
        for stream in (err.stdout, err.stderr):
            if stream:
                print(stream)
        return

    print(completed.stdout)
    if completed.stderr:
        print(completed.stderr)


def main():
    """Run the pipeline: check config, generate every prompt, then resize.

    Exits with status 1 if the API key is still the placeholder (and, via
    load_prompts, if the prompts file is missing). Prompt entries that are
    not JSON objects, lack a 'name' or 'prompt', or whose prompt is not a
    string are skipped with a warning and counted in the summary.
    resize_assets.py runs only when at least one image was saved.
    """
    print("🚀 IYmtg Image Generation Pipeline")
    print("=" * 40)

    if GEMINI_API_KEY == "YOUR_GEMINI_API_KEY_HERE":
        print("❌ Error: Gemini API key not set.")
        print("👉 Open this script and set GEMINI_API_KEY at the top.")
        sys.exit(1)

    # Ensure Raw_Assets directory exists
    os.makedirs(RAW_ASSETS_DIR, exist_ok=True)
    print(f"📂 Output directory: {RAW_ASSETS_DIR}\n")

    prompts = load_prompts()
    print(f"📋 Loaded {len(prompts)} image prompt(s) from {os.path.basename(PROMPTS_FILE)}\n")

    generated = 0
    failed = 0
    skipped = 0

    for item in prompts:
        # A non-object entry (e.g. a bare string) would crash on item.get();
        # skip it so one bad entry does not abort the whole batch.
        if not isinstance(item, dict):
            print(f"⚠️ Skipping entry that is not a JSON object: {item}")
            skipped += 1
            continue

        name = item.get("name")
        prompt = item.get("prompt")

        if not name or not prompt:
            print(f"⚠️ Skipping entry with missing 'name' or 'prompt': {item}")
            skipped += 1
            continue

        # The preview below slices and measures the prompt, which would raise
        # for a non-string value such as a number.
        if not isinstance(prompt, str):
            print(f"⚠️ Skipping entry whose 'prompt' is not a string: {item}")
            skipped += 1
            continue

        print(f"🎨 Generating: {name}")
        print(f" Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")

        image_bytes = generate_image(name, prompt)
        if image_bytes is None:
            failed += 1
            continue

        try:
            path = save_image(name, image_bytes)
            print(f" ✅ Saved: {os.path.basename(path)}")
            generated += 1
        except OSError as e:
            print(f" ❌ Failed to save '{name}': {e}")
            failed += 1

    print(f"\n{'=' * 40}")
    print(f"✅ Generated: {generated} | ❌ Failed: {failed}")
    if skipped:
        print(f"⚠️ Skipped: {skipped}")

    if generated > 0:
        run_resize_script()
    else:
        print("\n⚠️ No images were generated. Skipping resize step.")

    print("\n✅ Done.")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()