# NOTE(review): scrape-residue metadata ("338 lines / 17 KiB / Python", duplicated)
# preserved here as a comment so the module remains importable.
import os
|
||
import json
|
||
import time
|
||
import warnings
|
||
import threading
|
||
import google.generativeai as genai
|
||
from core import config, utils
|
||
from ai import models
|
||
|
||
|
||
# Shared request options for every genai.list_models call in this module:
# caps the network round-trip at 30 seconds so startup cannot hang indefinitely.
_LIST_MODELS_TIMEOUT = {"timeout": 30}
|
||
|
||
|
||
def get_optimal_model(base_type="pro"):
    """Return the best available Gemini model whose name contains *base_type*.

    Queries the live model list, keeps only models that support
    ``generateContent``, and ranks candidates by generation (2.5 > 2.0 > 2.x)
    and release channel (stable > "latest" alias > exp/beta/preview).
    On any failure — or when nothing matches — falls back to the
    1.5-series default name for the requested tier.
    """
    fallback_name = f"models/gemini-1.5-{base_type}"
    try:
        usable = [
            m for m in genai.list_models(request_options=_LIST_MODELS_TIMEOUT)
            if 'generateContent' in m.supported_generation_methods
        ]
        names = [m.name for m in usable if base_type in m.name]
        if not names:
            return fallback_name

        def rank(name):
            # Generation bonus: newer families outrank older ones.
            if "2.5" in name:
                bonus = 300
            elif "2.0" in name:
                bonus = 200
            elif "2." in name:
                bonus = 150
            else:
                bonus = 0
            # Channel penalty: experimental/preview builds score lowest,
            # "latest" aliases mid, plain stable names highest.
            if "exp" in name or "beta" in name or "preview" in name:
                return bonus
            if "latest" in name:
                return bonus + 50
            return bonus + 100

        # max() returns the first maximal element, matching a stable
        # reverse-sort followed by taking index 0.
        return max(names, key=rank)
    except Exception as e:
        utils.log("SYSTEM", f"⚠️ Error finding optimal model: {e}")
        return fallback_name
|
||
|
||
|
||
def get_default_models():
    """Return the hard-coded fallback role→model mapping.

    Used when AI-driven model selection is unavailable (API failure, no
    cache). Free experimental Pro handles logic-heavy roles; paid Flash
    handles creative/visual roles. ``ranking`` is intentionally empty —
    no live ranking data exists in fallback mode.
    """
    free_pro = "models/gemini-2.0-pro-exp"
    paid_flash = "models/gemini-2.0-flash"

    defaults = {}
    defaults["logic"] = {
        "model": free_pro,
        "reason": "Fallback: Gemini 2.0 Pro Exp (free) for cost-effective logic and JSON adherence.",
        "estimated_cost": "Free",
    }
    defaults["writer"] = {
        "model": paid_flash,
        "reason": "Fallback: Gemini 2.0 Flash for fast, high-quality creative writing.",
        "estimated_cost": "$0.10/1M",
    }
    defaults["artist"] = {
        "model": paid_flash,
        "reason": "Fallback: Gemini 2.0 Flash for visual prompt design.",
        "estimated_cost": "$0.10/1M",
    }
    defaults["pro_rewrite"] = {
        "model": free_pro,
        "reason": "Fallback: Gemini 2.0 Pro Exp (free) for critical chapter rewrites.",
        "estimated_cost": "Free",
    }
    defaults["ranking"] = []
    return defaults
|
||
|
||
|
||
def select_best_models(force_refresh=False):
    """Choose the Gemini models for each role (logic/writer/artist/pro_rewrite).

    Strategy, in order:
      1. Return the cached selection if it is <24h old and structurally valid
         (unless force_refresh is set).
      2. Otherwise ask a cheap "bootstrapper" Flash model to pick models from
         the live list, constrained by a hard cost budget, and cache the result.
      3. On any failure, fall back to the stale cache, then to hard-coded
         defaults from get_default_models().

    Returns a dict keyed by role; values are dicts with at least 'model',
    'reason' and 'estimated_cost' (plus 'book_cost' / 'total_estimated_book_cost'
    when produced by the AI path).
    """
    cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
    # Kept around even when the freshness check fails so the except-branch
    # below can fall back to a stale-but-usable selection.
    cached_models = None

    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'r') as f:
                cached = json.load(f)
            cached_models = cached.get('models', {})
            # 86400 s = 24 h cache TTL.
            if not force_refresh and time.time() - cached.get('timestamp', 0) < 86400:
                m = cached_models
                # Sanity check the shape: older cache formats stored bare
                # model-name strings instead of {'model', 'reason', ...} dicts.
                if isinstance(m.get('logic'), dict) and 'reason' in m['logic']:
                    utils.log("SYSTEM", "Using cached AI model selection (valid for 24h).")
                    return m
        except Exception as e:
            # Corrupt/unreadable cache is non-fatal — just refresh.
            utils.log("SYSTEM", f"Cache read failed: {e}. Refreshing models.")

    try:
        utils.log("SYSTEM", "Refreshing AI model list from API...")
        all_models = list(genai.list_models(request_options=_LIST_MODELS_TIMEOUT))
        raw_model_names = [m.name for m in all_models]
        utils.log("SYSTEM", f"Found {len(all_models)} raw models from Google API.")

        # Only text-generation-capable Gemini models are candidates.
        compatible = [m.name for m in all_models if 'generateContent' in m.supported_generation_methods and 'gemini' in m.name.lower()]
        utils.log("SYSTEM", f"Identified {len(compatible)} compatible Gemini models: {compatible}")

        # Use a cheap Flash model to make the selection itself.
        bootstrapper = get_optimal_model("flash")
        utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")

        model = genai.GenerativeModel(bootstrapper)
        # Runtime prompt — the pricing/budget figures below are consumed by
        # the model, not by this code; do not edit casually.
        prompt = f"""
ROLE: AI Model Architect
TASK: Select the optimal Gemini models for a book-writing application.
PRIMARY OBJECTIVE: Keep total book generation cost under $2.00. Quality is secondary to this budget.

AVAILABLE_MODELS:
{json.dumps(compatible)}

PRICING_CONTEXT (USD per 1M tokens — use these to calculate actual book cost):
- FREE TIER: Any model with 'exp', 'beta', or 'preview' in name = $0.00. Always prefer these.
e.g. gemini-2.0-pro-exp = FREE, gemini-2.5-pro-preview = FREE.
- gemini-2.5-flash / gemini-2.5-flash-preview: ~$0.075 Input / $0.30 Output.
- gemini-2.0-flash: ~$0.10 Input / $0.40 Output.
- gemini-1.5-flash: ~$0.075 Input / $0.30 Output.
- gemini-2.5-pro (stable, non-preview): ~$1.25 Input / $10.00 Output. BUDGET BREAKER.
- gemini-1.5-pro (stable): ~$1.25 Input / $5.00 Output. BUDGET BREAKER.

BOOK TOKEN BUDGET (30-chapter novel — use this to calculate real cost before deciding):
Logic role total: ~265,000 input tokens + ~55,000 output tokens
(planning, state tracking, consistency checks, director treatments per chapter)
Writer role total: ~450,000 input tokens + ~135,000 output tokens
(drafting, evaluation, refinement per chapter — 2 passes max)
Artist role total: ~30,000 input tokens + ~8,000 output tokens
(cover art prompt design, cover layout, blurb, image quality evaluation — text calls only)

NOTE: Cover IMAGE generation uses the Imagen API (billed per image, not per token).
Imagen costs are fixed at ~$0.04/image × up to 3 attempts = ~$0.12 max. This is SEPARATE
from the text token budget below and cannot be reduced by model selection.

COST FORMULA: cost = (input_tokens / 1,000,000 * input_price) + (output_tokens / 1,000,000 * output_price)
HARD BUDGET: Logic_cost + Writer_cost + Artist_cost (text only) must be < $1.85
(leaving $0.15 headroom for Imagen cover generation, total book target: $2.00).

SELECTION RULES (apply in order):
1. FREE FIRST: If a free/exp model exists (any tier, any quality), pick it for Logic. Cost = $0.
2. FLASH FOR WRITER: Flash is sufficient for fiction prose. Never pick a paid Pro for Writer.
3. CALCULATE: For non-free models, compute the actual book cost using the token budget above.
Reject any combination that exceeds $2.00 total.
4. QUALITY TIEBREAK: Among models with similar cost, prefer newer generation (2.x > 1.5).
5. NO THINKING MODELS: Too slow and expensive for any role.

ROLES:
- LOGIC: Planning, JSON adherence, plot consistency. Free/exp Pro ideal; Flash acceptable.
- WRITER: Creative prose, chapter drafting. Flash 2.x is sufficient — do NOT use paid Pro.
- ARTIST: Visual prompts for cover art. Cheapest capable Flash model.
- PRO_REWRITE: Emergency full-chapter rewrite (rare, ~1-2x per book). Best free/exp Pro available.
If no free Pro exists, use best Flash — do not use paid Pro even here.

OUTPUT_FORMAT (JSON only, no markdown):
{{
"logic": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX/1M", "book_cost": "$X.XX" }},
"writer": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX/1M", "book_cost": "$X.XX" }},
"artist": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX/1M", "book_cost": "$X.XX" }},
"pro_rewrite": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX/1M", "book_cost": "$X.XX" }},
"total_estimated_book_cost": "$X.XX",
"ranking": [ {{ "model": "string", "reason": "string", "estimated_cost": "string" }} ]
}}
"""

        try:
            response = model.generate_content(prompt)
            # clean_json presumably strips markdown fences before parsing —
            # TODO(review): confirm against utils.clean_json.
            selection = json.loads(utils.clean_json(response.text))
        except Exception as e:
            # Re-raise so the outer handler applies the cache/default fallback.
            utils.log("SYSTEM", f"Model selection generation failed (Safety/Format): {e}")
            raise e

        # Persist the selection plus the model lists seen at decision time
        # (useful for debugging why a given model was or wasn't chosen).
        if not os.path.exists(config.DATA_DIR): os.makedirs(config.DATA_DIR)
        with open(cache_path, 'w') as f:
            json.dump({
                "timestamp": int(time.time()),
                "models": selection,
                "available_at_time": compatible,
                "raw_models": raw_model_names
            }, f, indent=2)
        return selection

    except Exception as e:
        utils.log("SYSTEM", f"AI Model Selection failed: {e}.")

        # Prefer a stale cached selection over hard-coded defaults.
        if cached_models:
            utils.log("SYSTEM", "⚠️ Using stale cached models due to API failure.")
            return cached_models

        utils.log("SYSTEM", "Falling back to heuristics.")
        fallback = get_default_models()

        # Best-effort: cache the fallback (with the error) so the next run
        # can see why defaults were used. Failure to write is ignored.
        try:
            with open(cache_path, 'w') as f:
                json.dump({"timestamp": int(time.time()), "models": fallback, "error": str(e)}, f, indent=2)
        except: pass
        return fallback
|
||
|
||
|
||
def init_models(force=False):
    """Initialize (or re-initialize) all role models on the `models` module.

    Configures the genai SDK, runs AI-driven model selection (with cache),
    applies config hint overrides, constructs/updates the ResilientModel
    wrappers, and probes image-generation backends (Gemini API, then
    Vertex AI with optional OAuth). Idempotent unless *force* is set.
    Side effects: mutates many attributes on `models`, may mutate
    config.GCP_PROJECT, may delete GOOGLE_APPLICATION_CREDENTIALS from the
    environment, may write token.json, and may open a browser for OAuth.
    """
    global_vars = models.__dict__
    # Already initialized and no forced refresh requested — nothing to do.
    if global_vars.get('model_logic') and not force: return

    genai.configure(api_key=config.API_KEY)

    # Skip the credential-validation API call when a fresh (<24h) cache
    # exists — select_best_models would use the cache anyway.
    cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
    skip_validation = False
    if not force and os.path.exists(cache_path):
        try:
            with open(cache_path, 'r') as f: cached = json.load(f)
            if time.time() - cached.get('timestamp', 0) < 86400: skip_validation = True
        except: pass

    if not skip_validation:
        utils.log("SYSTEM", "Validating credentials...")
        try:
            # Minimal API call purely to verify the key works.
            list(genai.list_models(page_size=1, request_options=_LIST_MODELS_TIMEOUT))
            utils.log("SYSTEM", "✅ Gemini API Key is valid.")
        except Exception as e:
            # Non-fatal: select_best_models has its own cache/default fallback.
            if os.path.exists(cache_path):
                utils.log("SYSTEM", f"⚠️ API check failed ({e}), but cache exists. Attempting to use cached models.")
            else:
                utils.log("SYSTEM", f"⚠️ API check failed ({e}). No cache found. Attempting to initialize with defaults.")

    utils.log("SYSTEM", "Selecting optimal models via AI...")
    selected_models = select_best_models(force_refresh=force)

    # Older cache formats lack cost fields; force one refresh to upgrade.
    # Guarded by `not force` so the recursion is at most one level deep.
    if not force:
        missing_costs = False
        for role in ['logic', 'writer', 'artist']:
            role_data = selected_models.get(role, {})
            if 'estimated_cost' not in role_data or role_data.get('estimated_cost') == 'N/A':
                missing_costs = True
            if 'book_cost' not in role_data:
                missing_costs = True
        if 'total_estimated_book_cost' not in selected_models:
            missing_costs = True
        if missing_costs:
            utils.log("SYSTEM", "⚠️ Missing cost info in cached models. Forcing refresh.")
            return init_models(force=True)

    def get_model_details(role_data):
        # Accepts both dict entries and legacy bare model-name strings.
        if isinstance(role_data, dict):
            return role_data.get('model'), role_data.get('estimated_cost', 'N/A'), role_data.get('book_cost', 'N/A')
        return role_data, 'N/A', 'N/A'

    logic_name, logic_cost, logic_book = get_model_details(selected_models['logic'])
    writer_name, writer_cost, writer_book = get_model_details(selected_models['writer'])
    artist_name, artist_cost, artist_book = get_model_details(selected_models['artist'])
    pro_name, pro_cost, _ = get_model_details(selected_models.get('pro_rewrite', {'model': 'models/gemini-2.0-pro-exp', 'estimated_cost': 'Free', 'book_cost': '$0.00'}))
    total_book_cost = selected_models.get('total_estimated_book_cost', 'N/A')

    # Config hints override AI selection unless set to "AUTO".
    logic_name = logic_name if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
    writer_name = writer_name if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
    artist_name = artist_name if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT

    # Publish resolved names for the rest of the application.
    models.logic_model_name = logic_name
    models.writer_model_name = writer_name
    models.artist_model_name = artist_name
    models.pro_model_name = pro_name

    utils.log("SYSTEM", f"Models: Logic={logic_name} ({logic_cost}, {logic_book}/book) | Writer={writer_name} ({writer_cost}, {writer_book}/book) | Artist={artist_name} | Pro-Rewrite={pro_name} ({pro_cost})")
    utils.log("SYSTEM", f"💰 Estimated book cost: {total_book_cost} text + ~$0.00-$0.12 Imagen cover (budget: $2.00 total)")

    utils.update_pricing(logic_name, logic_cost)
    utils.update_pricing(writer_name, writer_cost)
    utils.update_pricing(artist_name, artist_cost)

    # Create wrappers on first init; on re-init just swap the model names
    # so any wrapper-internal state is retained.
    if models.model_logic is None:
        models.model_logic = models.ResilientModel(logic_name, utils.SAFETY_SETTINGS, "Logic")
        models.model_writer = models.ResilientModel(writer_name, utils.SAFETY_SETTINGS, "Writer")
        models.model_artist = models.ResilientModel(artist_name, utils.SAFETY_SETTINGS, "Artist")
    else:
        models.model_logic.update(logic_name)
        models.model_writer.update(writer_name)
        models.model_artist.update(artist_name)

    # --- Image generation backends (best-effort; may end up unavailable) ---
    models.model_image = None
    models.image_model_name = None
    models.image_model_source = "None"

    hint = config.MODEL_IMAGE_HINT if hasattr(config, 'MODEL_IMAGE_HINT') else "AUTO"

    # Path 1: Imagen via the Gemini API, if this SDK version exposes it.
    if hasattr(genai, 'ImageGenerationModel'):
        candidates = [hint] if hint and hint != "AUTO" else ["imagen-3.0-generate-001", "imagen-3.0-fast-generate-001"]
        for candidate in candidates:
            try:
                models.model_image = genai.ImageGenerationModel(candidate)
                models.image_model_name = candidate
                models.image_model_source = "Gemini API"
                utils.log("SYSTEM", f"✅ Image model: {candidate} (Gemini API)")
                break
            except Exception:
                continue

    # Auto-detect GCP Project from the OAuth client-secrets file when none
    # is configured explicitly.
    if models.HAS_VERTEX and not config.GCP_PROJECT and config.GOOGLE_CREDS and os.path.exists(config.GOOGLE_CREDS):
        try:
            with open(config.GOOGLE_CREDS, 'r') as f:
                cdata = json.load(f)
            # Client-secrets JSON nests its payload under 'installed' or 'web'.
            for k in ['installed', 'web']:
                if k in cdata and 'project_id' in cdata[k]:
                    config.GCP_PROJECT = cdata[k]['project_id']
                    utils.log("SYSTEM", f"Auto-detected GCP Project ID: {config.GCP_PROJECT}")
                    break
        except: pass

    # Path 2: Imagen via Vertex AI (requires a project; creds may come from
    # an OAuth browser flow, a saved token, or ambient ADC).
    if models.HAS_VERTEX and config.GCP_PROJECT:
        creds = None
        if models.HAS_OAUTH:
            gac = config.GOOGLE_CREDS
            if gac and os.path.exists(gac):
                try:
                    with open(gac, 'r') as f: data = json.load(f)
                    # Only OAuth client-ID files have these keys; service
                    # account keys fall through to ADC handling.
                    if 'installed' in data or 'web' in data:
                        # A client-ID file is not valid as ADC — remove the
                        # env var so downstream libs don't try to use it.
                        if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
                            del os.environ["GOOGLE_APPLICATION_CREDENTIALS"]

                        token_path = os.path.join(os.path.dirname(os.path.abspath(gac)), 'token.json')
                        SCOPES = ['https://www.googleapis.com/auth/cloud-platform']

                        # Reuse a previously saved user token when present.
                        if os.path.exists(token_path):
                            creds = models.Credentials.from_authorized_user_file(token_path, SCOPES)

                        # Heuristic: off-main-thread means no browser can be
                        # launched for interactive auth.
                        _is_headless = threading.current_thread() is not threading.main_thread()

                        if not creds or not creds.valid:
                            if creds and creds.expired and creds.refresh_token:
                                try:
                                    creds.refresh(models.Request())
                                except Exception:
                                    if _is_headless:
                                        utils.log("SYSTEM", "⚠️ Token refresh failed and cannot re-authenticate in a background/headless thread. Vertex AI will use ADC or be unavailable.")
                                        creds = None
                                    else:
                                        utils.log("SYSTEM", "Token refresh failed. Re-authenticating...")
                                        flow = models.InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
                                        creds = flow.run_local_server(port=0)
                            else:
                                if _is_headless:
                                    utils.log("SYSTEM", "⚠️ OAuth Client ID requires browser login but running in headless/background mode. Skipping interactive auth. Use a Service Account key for Vertex AI in background tasks.")
                                    creds = None
                                else:
                                    utils.log("SYSTEM", "OAuth Client ID detected. Launching browser to authenticate...")
                                    flow = models.InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
                                    creds = flow.run_local_server(port=0)
                            # Persist the (re)obtained token for next run.
                            if creds:
                                with open(token_path, 'w') as token: token.write(creds.to_json())

                        utils.log("SYSTEM", "✅ Authenticated via OAuth Client ID.")
                except Exception as e:
                    utils.log("SYSTEM", f"⚠️ OAuth check failed: {e}")

        # creds=None lets vertexai fall back to Application Default Credentials.
        import vertexai as _vertexai
        _vertexai.init(project=config.GCP_PROJECT, location=config.GCP_LOCATION, credentials=creds)
        utils.log("SYSTEM", f"✅ Vertex AI initialized (Project: {config.GCP_PROJECT})")

        vertex_candidates = [hint] if hint and hint != "AUTO" else ["imagen-3.0-generate-001", "imagen-3.0-fast-generate-001"]
        for candidate in vertex_candidates:
            try:
                models.model_image = models.VertexImageModel.from_pretrained(candidate)
                models.image_model_name = candidate
                models.image_model_source = "Vertex AI"
                utils.log("SYSTEM", f"✅ Image model: {candidate} (Vertex AI)")
                break
            except Exception:
                continue

    utils.log("SYSTEM", f"Image Generation Provider: {models.image_model_source} ({models.image_model_name or 'unavailable'})")
|