"""Shared helpers: token estimation, AI response caching, logging, and usage/cost logs."""

import datetime
import hashlib
import json
import os
import re
import threading
import time

from core import config

SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]

# Thread-local storage for logging context
_log_context = threading.local()

# Cache for dynamic pricing from AI model selection
PRICING_CACHE = {}

# Heuristic used by the token helpers below: one token per four characters.
_CHARS_PER_TOKEN = 4

# Fallback rates as (input, output) cost per 1,000,000 tokens, used when a
# model has no entry in PRICING_CACHE.
_FLASH_RATES = (0.075, 0.30)
_PRO_RATES = (3.50, 10.50)
# Fallback flat cost per generated image.
_IMAGE_COST = 0.04


# --- Token Estimation & Truncation Utilities ---

def estimate_tokens(text):
    """Estimate token count using a 4-chars-per-token heuristic (no external libs required).

    Returns 0 for empty/None input, otherwise at least 1.
    """
    if not text:
        return 0
    return max(1, len(text) // _CHARS_PER_TOKEN)


def truncate_to_tokens(text, max_tokens):
    """Truncate text to approximately max_tokens, keeping the most recent (tail) content.

    Args:
        text: The text to truncate; empty/None is returned unchanged.
        max_tokens: Token budget; converted to characters with the
            4-chars-per-token heuristic.

    Returns:
        ``text`` itself when it fits, otherwise its last ``max_tokens * 4``
        characters. A budget of zero or less yields an empty value.
    """
    if not text:
        return text
    max_chars = max_tokens * _CHARS_PER_TOKEN
    if max_chars <= 0:
        # text[-0:] would be the WHOLE text and a negative budget would slice
        # from the front, so handle the non-positive case explicitly.
        return text[:0]
    if len(text) <= max_chars:
        return text
    return text[-max_chars:]


# --- In-Memory AI Response Cache ---

_AI_CACHE = {}


def get_ai_cache(key):
    """Retrieve a cached AI response by key. Returns None if not cached."""
    return _AI_CACHE.get(key)


def set_ai_cache(key, value):
    """Store an AI response in the in-memory cache keyed by a hash string."""
    _AI_CACHE[key] = value


def make_cache_key(prefix, *parts):
    """Build a stable MD5 cache key from a prefix and variable string parts.

    The parts are stringified and joined with ``|`` before hashing; the
    result has the form ``"<prefix>:<md5 hex digest>"``.
    """
    raw = "|".join(str(p) for p in parts)
    digest = hashlib.md5(raw.encode('utf-8', errors='replace')).hexdigest()
    return f"{prefix}:{digest}"


# --- Logging context (stored per thread on _log_context) ---

def set_log_file(filepath):
    """Set the file that log() appends to for the current thread."""
    _log_context.log_file = filepath


def set_log_callback(callback):
    """Set the callback log() invokes as ``callback(phase, msg)`` for the current thread."""
    _log_context.callback = callback


def set_progress_callback(callback):
    """Set the callback update_progress() invokes for the current thread."""
    _log_context.progress_callback = callback


def update_progress(percent):
    """Forward ``percent`` to the registered progress callback, if any."""
    callback = getattr(_log_context, 'progress_callback', None)
    if callback:
        try:
            callback(percent)
        except Exception:
            # Best-effort: a failing callback must not break the caller, but
            # KeyboardInterrupt/SystemExit are allowed to propagate.
            pass


def clean_json(text):
    """Strip Markdown code fences and return the outermost JSON-looking span.

    Whichever of ``{`` or ``[`` appears first decides whether the span runs
    to the last ``}`` or the last ``]``. If neither opener is present the
    fence-stripped text is returned as-is.
    """
    text = text.replace("```json", "").replace("```", "").strip()
    start_obj = text.find('{')
    start_arr = text.find('[')
    if start_obj == -1 and start_arr == -1:
        return text
    if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
        return text[start_obj:text.rfind('}') + 1]
    return text[start_arr:text.rfind(']') + 1]


def sanitize_filename(name):
    """Reduce ``name`` to alphanumerics and underscores; fall back to "Untitled".

    Every other character, including spaces, is dropped rather than
    replaced.
    """
    if not name:
        return "Untitled"
    # NOTE(review): the original chained .replace(" ", "_") after this filter,
    # which could never match because spaces were already removed. Output is
    # kept identical so existing filenames stay valid.
    safe = "".join(c for c in name if c.isalnum() or c == '_')
    return safe if safe else "Untitled"


def chapter_sort_key(ch):
    """Return a numeric sort key for a chapter dict based on its 'num' entry.

    Integers and digit-only strings sort by value; anything containing
    "prologue" sorts as -1, "epilogue" as 9999, and everything else as 999.
    """
    num = ch.get('num', 0)
    if isinstance(num, int):
        return num
    if isinstance(num, str) and num.isdigit():
        return int(num)
    label = str(num).lower().strip()
    if 'prologue' in label:
        return -1
    if 'epilogue' in label:
        return 9999
    return 999


def get_sorted_book_folders(run_dir):
    """List the ``Book_*`` subdirectories of ``run_dir`` ordered by their number.

    The number is the second ``_``-separated field; folders without a numeric
    field sort as 0. Returns an empty list when ``run_dir`` does not exist.
    """
    if not os.path.exists(run_dir):
        return []
    subdirs = [
        d for d in os.listdir(run_dir)
        if os.path.isdir(os.path.join(run_dir, d)) and d.startswith("Book_")
    ]

    def sort_key(d):
        parts = d.split('_')
        if len(parts) > 1 and parts[1].isdigit():
            return int(parts[1])
        return 0

    return sorted(subdirs, key=sort_key)


def log_banner(phase, title):
    """Log ``title`` framed by horizontal rules."""
    log(phase, f"{'─' * 18} {title} {'─' * 18}")


def log(phase, msg):
    """Print a timestamped log line and mirror it to the thread's log file and callback.

    The line is always printed. It is appended to the file set via
    set_log_file() when one is set, and the callback set via
    set_log_callback() receives ``(phase, msg)`` when one is set.
    """
    timestamp = datetime.datetime.now().strftime('%H:%M:%S')
    line = f"[{timestamp}] {phase:<15} | {msg}"
    print(line)

    log_file = getattr(_log_context, 'log_file', None)
    if log_file:
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(line + "\n")

    callback = getattr(_log_context, 'callback', None)
    if callback:
        try:
            callback(phase, msg)
        except Exception:
            # Best-effort: a broken callback must not take logging down.
            pass


def load_json(path):
    """Load and return the JSON document at ``path``, or None if the path does not exist."""
    if not os.path.exists(path):
        return None
    # Context manager so the handle is closed even if parsing fails.
    with open(path, 'r') as f:
        return json.load(f)


def create_default_personas():
    """Ensure the personas directory exists and seed an empty personas file if missing."""
    if not os.path.exists(config.PERSONAS_DIR):
        # exist_ok guards against another thread/process creating it between
        # the check and the call.
        os.makedirs(config.PERSONAS_DIR, exist_ok=True)
    if not os.path.exists(config.PERSONAS_FILE):
        try:
            with open(config.PERSONAS_FILE, 'w') as f:
                json.dump({}, f, indent=2)
        except OSError:
            # Best-effort: failing to write the default file is not fatal here.
            pass


def get_length_presets():
    """Return config.LENGTH_DEFINITIONS re-keyed by each definition's 'label'."""
    return {v['label']: v for v in config.LENGTH_DEFINITIONS.values()}


def log_image_attempt(folder, img_type, prompt, filename, status, error=None, score=None, critique=None):
    """Append one image-generation attempt to ``<folder>/image_log.json``.

    An unreadable, malformed, or non-list existing log is replaced by a
    fresh list containing only the new entry.
    """
    log_path = os.path.join(folder, "image_log.json")
    entry = {
        "timestamp": int(time.time()),
        "type": img_type,
        "prompt": prompt,
        "filename": filename,
        "status": status,
        "error": str(error) if error else None,
        "score": score,
        "critique": critique
    }

    data = []
    if os.path.exists(log_path):
        try:
            with open(log_path, 'r') as f:
                loaded = json.load(f)
            # Anything other than a list cannot be appended to.
            if isinstance(loaded, list):
                data = loaded
        except (OSError, ValueError):
            pass

    data.append(entry)
    with open(log_path, 'w') as f:
        json.dump(data, f, indent=2)


def get_run_folder(base_name):
    """Create and return the next ``run_<N>`` folder under ``base_name``.

    N is one more than the highest existing numeric ``run_*`` suffix, or 1
    when there is none. ``base_name`` is created if missing.
    """
    if not os.path.exists(base_name):
        os.makedirs(base_name, exist_ok=True)
    runs = [d for d in os.listdir(base_name) if d.startswith("run_")]
    numbers = [int(r.split("_")[1]) for r in runs if r.split("_")[1].isdigit()]
    next_num = max(numbers + [0]) + 1
    folder = os.path.join(base_name, f"run_{next_num}")
    os.makedirs(folder)
    return folder


def get_latest_run_folder(base_name):
    """Return the path of the highest-numbered ``run_*`` entry under ``base_name``, or None."""
    if not os.path.exists(base_name):
        return None
    runs = [d for d in os.listdir(base_name) if d.startswith("run_")]
    if not runs:
        return None
    runs.sort(key=lambda x: int(x.split('_')[1]) if x.split('_')[1].isdigit() else 0)
    return os.path.join(base_name, runs[-1])


def update_pricing(model_name, cost_str):
    """Parse prices out of ``cost_str`` and cache them for ``model_name``.

    The first two ``$``/``USD``-prefixed numbers are taken as the input and
    output rates. When only one is found, the output rate is set to three
    times the input rate. Nothing is cached unless the input rate is
    positive. Unparseable input is ignored.
    """
    if not model_name or not cost_str or cost_str == 'N/A':
        return
    try:
        in_cost = 0.0
        out_cost = 0.0
        prices = re.findall(r'(?:\$|USD)\s*([0-9]+\.?[0-9]*)', cost_str, re.IGNORECASE)
        if len(prices) >= 2:
            in_cost = float(prices[0])
            out_cost = float(prices[1])
        elif len(prices) == 1:
            in_cost = float(prices[0])
            out_cost = in_cost * 3
        if in_cost > 0:
            PRICING_CACHE[model_name] = {"input": in_cost, "output": out_cost}
    except (TypeError, ValueError):
        # Non-string or otherwise unparseable cost info: leave the cache untouched.
        pass


def _token_cost(input_tokens, output_tokens, input_rate, output_rate):
    """Return the cost of the given token counts at per-1,000,000-token rates."""
    return (input_tokens / 1_000_000 * input_rate) + (output_tokens / 1_000_000 * output_rate)


def calculate_cost(model_label, input_tokens, output_tokens, image_count=0):
    """Estimate the cost of one call, rounded to 6 decimal places.

    Rates cached in PRICING_CACHE for ``model_label`` take precedence. The
    fallbacks, in order, are: a flat per-image cost when the label contains
    "imagen" or ``image_count`` is positive; the "flash" token rates; the
    "pro"/"logic" token rates. Any other label costs 0.
    """
    # Tolerate a missing label instead of raising on None.lower().
    label = (model_label or "").lower()
    cost = 0.0
    if model_label in PRICING_CACHE:
        rates = PRICING_CACHE[model_label]
        cost = _token_cost(input_tokens, output_tokens, rates['input'], rates['output'])
    elif 'imagen' in label or image_count > 0:
        cost = image_count * _IMAGE_COST
    elif 'flash' in label:
        cost = _token_cost(input_tokens, output_tokens, *_FLASH_RATES)
    elif 'pro' in label or 'logic' in label:
        cost = _token_cost(input_tokens, output_tokens, *_PRO_RATES)
    return round(cost, 6)


def _legacy_entry_cost(entry):
    """Estimate the cost of a stored log entry that has no 'cost' field.

    The branch order (flash, then pro/logic, then images) intentionally
    differs from calculate_cost() and matches how such entries have always
    been totalled.
    """
    model = entry.get('model', '').lower()
    tokens_in = entry.get('input_tokens', 0)
    tokens_out = entry.get('output_tokens', 0)
    images = entry.get('images', 0)
    if 'flash' in model:
        return _token_cost(tokens_in, tokens_out, *_FLASH_RATES)
    if 'pro' in model or 'logic' in model:
        return _token_cost(tokens_in, tokens_out, *_PRO_RATES)
    if 'imagen' in model or images > 0:
        return images * _IMAGE_COST
    return 0.0


def log_usage(folder, model_label, usage_metadata=None, image_count=0):
    """Append a usage entry to ``<folder>/usage_log.json`` and refresh its totals.

    Does nothing when ``folder`` is empty or does not exist. Token counts
    are read from ``usage_metadata.prompt_token_count`` and
    ``usage_metadata.candidates_token_count``; missing or None counts are
    treated as 0. An existing log stored as a bare list is converted to the
    ``{"log": [...], "totals": {...}}`` layout.
    """
    if not folder or not os.path.exists(folder):
        return

    log_path = os.path.join(folder, "usage_log.json")

    input_tokens = 0
    output_tokens = 0
    if usage_metadata:
        try:
            # "or 0" keeps None counts from reaching the cost arithmetic.
            input_tokens = usage_metadata.prompt_token_count or 0
            output_tokens = usage_metadata.candidates_token_count or 0
        except Exception:
            pass

    cost = calculate_cost(model_label, input_tokens, output_tokens, image_count)

    entry = {
        "timestamp": int(time.time()),
        "date": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "model": model_label,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "images": image_count,
        "cost": round(cost, 6)
    }

    data = {"log": [], "totals": {"input_tokens": 0, "output_tokens": 0, "images": 0, "est_cost_usd": 0.0}}
    if os.path.exists(log_path):
        try:
            with open(log_path, 'r') as f:
                loaded = json.load(f)
            if isinstance(loaded, list):
                data["log"] = loaded
            elif isinstance(loaded, dict):
                data = loaded
        except (OSError, ValueError):
            pass

    # A dict loaded from disk may lack a usable "log" list.
    if not isinstance(data.get("log"), list):
        data["log"] = []
    data["log"].append(entry)

    t_in = sum(x.get('input_tokens', 0) for x in data["log"])
    t_out = sum(x.get('output_tokens', 0) for x in data["log"])
    t_img = sum(x.get('images', 0) for x in data["log"])

    total_cost = 0.0
    for x in data["log"]:
        if 'cost' in x:
            total_cost += x['cost']
        else:
            total_cost += _legacy_entry_cost(x)

    data["totals"] = {
        "input_tokens": t_in,
        "output_tokens": t_out,
        "images": t_img,
        "est_cost_usd": round(total_cost, 4)
    }

    with open(log_path, 'w') as f:
        json.dump(data, f, indent=2)