Files
bookapp/modules/ai.py
2026-02-20 09:55:21 -05:00

334 lines
15 KiB
Python

import os
import json
import time
import warnings
import google.generativeai as genai
import config
from . import utils
# Suppress Vertex AI warnings
warnings.filterwarnings("ignore", category=UserWarning, module="vertexai")
try:
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel as VertexImageModel
HAS_VERTEX = True
except ImportError:
HAS_VERTEX = False
try:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
HAS_OAUTH = True
except ImportError:
HAS_OAUTH = False
# Shared model handles for each application role. They stay None until
# init_models() assigns them (it declares all of these names as globals).
model_logic = None
model_writer = None
model_artist = None
model_image = None
# Default model identifiers per role; init_models() rebinds these to the
# names it finally selects (AI selection or the config hint override).
logic_model_name = "models/gemini-1.5-pro"
writer_model_name = "models/gemini-1.5-flash"
artist_model_name = "models/gemini-1.5-flash"
class ResilientModel:
    """Wrapper around ``genai.GenerativeModel`` that retries transient failures.

    The wrapped model can be swapped in place via :meth:`update`, so callers
    holding a reference to this object keep working after the module re-runs
    model selection.
    """

    # Substrings (matched against the lower-cased error text) that mark a
    # failure as transient and therefore worth retrying.
    RETRYABLE_MARKERS = ("429", "quota", "500", "503", "504", "deadline", "internal error")
    # Number of retries after the first failed attempt.
    MAX_RETRIES = 3
    # Initial back-off in seconds; doubled after every failed attempt.
    BASE_DELAY = 5

    def __init__(self, name, safety_settings, role):
        """Create the underlying model.

        Args:
            name: Model identifier passed to ``genai.GenerativeModel``.
            safety_settings: Safety settings forwarded to the model.
            role: Human-readable role label, used only in log messages.
        """
        self.name = name
        self.safety_settings = safety_settings
        self.role = role
        self.model = genai.GenerativeModel(name, safety_settings=safety_settings)

    def update(self, name):
        """Replace the underlying model with ``name``, keeping the safety settings."""
        self.name = name
        self.model = genai.GenerativeModel(name, safety_settings=self.safety_settings)

    def generate_content(self, *args, **kwargs):
        """Forward to the wrapped model's ``generate_content`` with retries.

        Transient errors (see ``RETRYABLE_MARKERS``) are retried up to
        ``MAX_RETRIES`` times with exponential back-off. Before the first
        retry the module-level model selection is re-run, which may swap the
        wrapped model through :meth:`update`.

        Raises:
            Exception: The original error, when it is not retryable or the
                retry budget is exhausted.
        """
        retries = 0
        while True:
            try:
                return self.model.generate_content(*args, **kwargs)
            except Exception as e:
                err_str = str(e).lower()
                is_retryable = any(marker in err_str for marker in self.RETRYABLE_MARKERS)
                if not is_retryable or retries >= self.MAX_RETRIES:
                    # Bare raise keeps the original traceback intact.
                    raise
                delay = self.BASE_DELAY * (2 ** retries)
                utils.log(
                    "SYSTEM",
                    f"⚠️ Transient error on {self.role} ({self.name}): {e}. Retrying in {delay}s...",
                )
                time.sleep(delay)
                # On first retry, attempt to re-optimize/rotate models.
                if retries == 0:
                    utils.log("SYSTEM", "Attempting to re-optimize models to find alternative...")
                    try:
                        # Note: init_models calls .update() on this instance.
                        init_models(force=True)
                    except Exception as reinit_err:
                        # A failed re-selection must not mask the original
                        # error or abort the retry loop.
                        utils.log(
                            "SYSTEM",
                            f"⚠️ Model re-optimization failed: {reinit_err}. Retrying with current model.",
                        )
                retries += 1
def get_optimal_model(base_type="pro"):
    """Pick the most stable available model whose name contains ``base_type``.

    Only models that support ``generateContent`` are considered. Stable names
    outrank ``latest`` aliases, which outrank experimental/beta/preview ones;
    ties keep the order in which the API listed the models.

    Returns:
        The chosen model name, or ``models/gemini-1.5-<base_type>`` when no
        model matches or the lookup fails.
    """
    default_name = f"models/gemini-1.5-{base_type}"

    def stability_rank(model_name):
        # Prioritize stable models (higher quotas) over experimental/beta ones.
        if any(tag in model_name for tag in ("exp", "beta", "preview")):
            return 0
        return 50 if "latest" in model_name else 100

    try:
        matching = []
        for entry in genai.list_models():
            if 'generateContent' not in entry.supported_generation_methods:
                continue
            if base_type in entry.name:
                matching.append(entry.name)
        if not matching:
            return default_name
        # max() yields the first best-ranked name, same as a stable
        # descending sort followed by taking element 0.
        return max(matching, key=stability_rank)
    except Exception as e:
        utils.log("SYSTEM", f"⚠️ Error finding optimal model: {e}")
        return default_name
def get_default_models():
    """Return the hard-coded model selection used when nothing better is available.

    The result maps ``logic`` to the default Pro model and ``writer`` /
    ``artist`` to the default Flash model, plus an empty ``ranking`` list.
    A fresh structure is built on every call.
    """
    pro_entry = {
        "model": "models/gemini-1.5-pro",
        "reason": "Fallback: Default Pro model selected.",
        "estimated_cost": "$3.50/1M",
    }
    flash_entry = {
        "model": "models/gemini-1.5-flash",
        "reason": "Fallback: Default Flash model selected.",
        "estimated_cost": "$0.075/1M",
    }

    defaults = {"logic": pro_entry}
    for role in ("writer", "artist"):
        # Copy so each role owns an independent dict.
        defaults[role] = dict(flash_entry)
    defaults["ranking"] = []
    return defaults
def select_best_models(force_refresh=False):
    """
    Uses a safe bootstrapper model to analyze available models and pick the best ones.
    Caches the result for 24 hours.

    Args:
        force_refresh: When True, ignore a fresh cache and query the API again.

    Returns:
        The model selection. A fresh selection is a dict with "logic",
        "writer" and "artist" entries (validated here) as produced by the
        bootstrapper model. On failure this is whatever the stale cache held
        under "models" (not re-validated), or else the result of
        get_default_models().
    """
    cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
    cached_models = None

    # 1. Check Cache
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'r') as f:
                cached = json.load(f)
            cached_models = cached.get('models', {})
            # Check if within 24 hours (86400 seconds)
            if not force_refresh and time.time() - cached.get('timestamp', 0) < 86400:
                models = cached_models
                # Validate format (must be dicts with reasons, not just strings)
                if isinstance(models.get('logic'), dict) and 'reason' in models['logic']:
                    utils.log("SYSTEM", "Using cached AI model selection (valid for 24h).")
                    return models
        except Exception as e:
            utils.log("SYSTEM", f"Cache read failed: {e}. Refreshing models.")

    # 2. Ask the API which models exist and let a bootstrapper model choose.
    try:
        utils.log("SYSTEM", "Refreshing AI model list from API...")
        all_models = list(genai.list_models())
        raw_model_names = [m.name for m in all_models]
        utils.log("SYSTEM", f"Found {len(all_models)} raw models from Google API.")
        models = [m.name for m in all_models if 'generateContent' in m.supported_generation_methods and 'gemini' in m.name.lower()]
        utils.log("SYSTEM", f"Identified {len(models)} compatible Gemini models: {models}")
        bootstrapper = get_optimal_model("flash")
        utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")
        model = genai.GenerativeModel(bootstrapper)
        prompt = f"""
ROLE: AI Model Architect
TASK: Select the optimal Gemini models for specific application roles.
AVAILABLE_MODELS:
{json.dumps(models)}
PRICING_CONTEXT (USD per 1M tokens):
- Flash Models (e.g. gemini-1.5-flash): ~$0.075 Input / $0.30 Output. (Very Cheap)
- Pro Models (e.g. gemini-1.5-pro): ~$3.50 Input / $10.50 Output. (Expensive)
CRITERIA:
- LOGIC: Needs complex reasoning, JSON adherence, and instruction following. (Prefer Pro/1.5).
- WRITER: Needs creativity, prose quality, and speed. (Prefer Flash/1.5 for speed, or Pro for quality).
- ARTIST: Needs visual prompt understanding.
CONSTRAINTS:
- Avoid 'experimental' or 'preview' unless no stable version exists.
- Prioritize 'latest' or stable versions.
OUTPUT_FORMAT (JSON):
{{
"logic": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
"writer": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
"artist": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
"ranking": [ {{ "model": "string", "reason": "string", "estimated_cost": "string" }} ]
}}
"""
        try:
            response = model.generate_content(prompt)
            selection = json.loads(utils.clean_json(response.text))
            # Callers index the three roles directly, so reject a response
            # that lacks them instead of caching an unusable selection.
            if not isinstance(selection, dict):
                raise ValueError("model selection is not a JSON object")
            missing_roles = [role for role in ('logic', 'writer', 'artist') if role not in selection]
            if missing_roles:
                raise ValueError(f"model selection is missing roles: {missing_roles}")
        except Exception as e:
            utils.log("SYSTEM", f"Model selection generation failed (Safety/Format): {e}")
            raise

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(config.DATA_DIR, exist_ok=True)
        with open(cache_path, 'w') as f:
            json.dump({
                "timestamp": int(time.time()),
                "models": selection,
                "available_at_time": models,
                "raw_models": raw_model_names
            }, f, indent=2)
        return selection
    except Exception as e:
        utils.log("SYSTEM", f"AI Model Selection failed: {e}.")

        # 3. Fallback to Stale Cache if available (Better than heuristics)
        # Relaxed check: If we successfully loaded ANY JSON from the cache, use it.
        if cached_models:
            utils.log("SYSTEM", "⚠️ Using stale cached models due to API failure.")
            return cached_models

        utils.log("SYSTEM", "Falling back to heuristics.")
        fallback = get_default_models()

        # Persist the fallback so there is a valid file on disk for the
        # web UI to read. Best-effort: a write failure is logged, not raised.
        try:
            os.makedirs(config.DATA_DIR, exist_ok=True)
            with open(cache_path, 'w') as f:
                json.dump({"timestamp": int(time.time()), "models": fallback, "error": str(e)}, f, indent=2)
        except Exception as write_err:
            utils.log("SYSTEM", f"⚠️ Could not write fallback model cache: {write_err}")
        return fallback
def init_models(force=False):
    """Select and initialize the logic/writer/artist models and the image model.

    Steps, in order:
      1. Configure the Gemini client with ``config.API_KEY``.
      2. Validate the API key unless a cache file younger than 24h exists.
      3. Obtain a model selection from ``select_best_models`` (refreshing once
         if cost information is missing) and apply the ``config.MODEL_*_HINT``
         overrides unless they are ``"AUTO"``.
      4. Create the ``ResilientModel`` wrappers, or update existing ones in place.
      5. Pick an image model: the Gemini API one if available, overridden by
         the Vertex AI one when Vertex is installed and a GCP project is known.

    Args:
        force: When True, re-run selection even if models are already
            initialized and bypass the cache freshness shortcut.
    """
    global model_logic, model_writer, model_artist, model_image, logic_model_name, writer_model_name, artist_model_name
    if model_logic and not force:
        return
    genai.configure(api_key=config.API_KEY)

    # Check cache to skip frequent validation
    cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
    skip_validation = False
    if not force and os.path.exists(cache_path):
        try:
            with open(cache_path, 'r') as f:
                cached = json.load(f)
            if time.time() - cached.get('timestamp', 0) < 86400:
                skip_validation = True
        except Exception:
            # Unreadable cache simply means we validate the credentials below.
            pass

    if not skip_validation:
        # Validate Gemini API Key
        utils.log("SYSTEM", "Validating credentials...")
        try:
            list(genai.list_models(page_size=1))
            utils.log("SYSTEM", "✅ Gemini API Key is valid.")
        except Exception as e:
            # Check if we have a cache file we can rely on before exiting
            if os.path.exists(cache_path):
                utils.log("SYSTEM", f"⚠️ API check failed ({e}), but cache exists. Attempting to use cached models.")
            else:
                utils.log("SYSTEM", f"⚠️ API check failed ({e}). No cache found. Attempting to initialize with defaults.")

    utils.log("SYSTEM", "Selecting optimal models via AI...")
    selected_models = select_best_models(force_refresh=force)

    # Check for missing costs and force refresh if needed
    if not force:
        missing_costs = False
        for role in ['logic', 'writer', 'artist']:
            if 'estimated_cost' not in selected_models.get(role, {}) or selected_models[role].get('estimated_cost') == 'N/A':
                missing_costs = True
        if missing_costs:
            utils.log("SYSTEM", "⚠️ Missing cost info in cached models. Forcing refresh.")
            return init_models(force=True)

    def get_model_details(role_data):
        # Entries are dicts in the current format; a bare string is treated
        # as a model name with unknown cost.
        if isinstance(role_data, dict):
            return role_data.get('model'), role_data.get('estimated_cost', 'N/A')
        return role_data, 'N/A'

    logic_name, logic_cost = get_model_details(selected_models['logic'])
    writer_name, writer_cost = get_model_details(selected_models['writer'])
    artist_name, artist_cost = get_model_details(selected_models['artist'])

    # A config hint other than "AUTO" overrides the selected model name.
    logic_name = logic_model_name = logic_name if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
    writer_name = writer_model_name = writer_name if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
    artist_name = artist_model_name = artist_name if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT
    utils.log("SYSTEM", f"Models: Logic={logic_name} ({logic_cost}) | Writer={writer_name} ({writer_cost}) | Artist={artist_name}")

    # Update pricing in utils
    utils.update_pricing(logic_name, logic_cost)
    utils.update_pricing(writer_name, writer_cost)
    utils.update_pricing(artist_name, artist_cost)

    # Initialize or Update Resilient Models
    if model_logic is None:
        model_logic = ResilientModel(logic_name, utils.SAFETY_SETTINGS, "Logic")
        model_writer = ResilientModel(writer_name, utils.SAFETY_SETTINGS, "Writer")
        model_artist = ResilientModel(artist_name, utils.SAFETY_SETTINGS, "Artist")
    else:
        # If models already exist (re-init), update them in place
        model_logic.update(logic_name)
        model_writer.update(writer_name)
        model_artist.update(artist_name)

    # Initialize Image Model (Default to None)
    model_image = None
    if hasattr(genai, 'ImageGenerationModel'):
        try:
            model_image = genai.ImageGenerationModel("imagen-3.0-generate-001")
        except Exception as e:
            # Best-effort: image generation is optional.
            utils.log("SYSTEM", f"⚠️ Gemini image model unavailable: {e}")
    img_source = "Gemini API" if model_image else "None"

    # Auto-detect GCP Project from credentials if not set (Fix for Image Model)
    if HAS_VERTEX and not config.GCP_PROJECT and config.GOOGLE_CREDS and os.path.exists(config.GOOGLE_CREDS):
        try:
            with open(config.GOOGLE_CREDS, 'r') as f:
                cdata = json.load(f)
            # Check common OAuth structures
            for k in ['installed', 'web']:
                if k in cdata and 'project_id' in cdata[k]:
                    config.GCP_PROJECT = cdata[k]['project_id']
                    utils.log("SYSTEM", f"Auto-detected GCP Project ID: {config.GCP_PROJECT}")
                    break
        except Exception as e:
            utils.log("SYSTEM", f"⚠️ Could not auto-detect GCP Project ID: {e}")

    if HAS_VERTEX and config.GCP_PROJECT:
        creds = None
        # Handle OAuth Client ID (credentials.json) if provided instead of Service Account
        if HAS_OAUTH:
            gac = config.GOOGLE_CREDS  # Use persistent config, not volatile env var
            if gac and os.path.exists(gac):
                try:
                    with open(gac, 'r') as f:
                        data = json.load(f)
                    if 'installed' in data or 'web' in data:
                        # It's an OAuth Client ID. Unset env var to avoid library crash.
                        os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)
                        token_path = os.path.join(os.path.dirname(os.path.abspath(gac)), 'token.json')
                        SCOPES = ['https://www.googleapis.com/auth/cloud-platform']
                        if os.path.exists(token_path):
                            creds = Credentials.from_authorized_user_file(token_path, SCOPES)
                        if not creds or not creds.valid:
                            # Try a silent refresh first; fall back to the
                            # interactive browser flow when that is not
                            # possible or fails.
                            needs_login = True
                            if creds and creds.expired and creds.refresh_token:
                                try:
                                    creds.refresh(Request())
                                    needs_login = False
                                except Exception:
                                    utils.log("SYSTEM", "Token refresh failed. Re-authenticating...")
                            else:
                                utils.log("SYSTEM", "OAuth Client ID detected. Launching browser to authenticate...")
                            if needs_login:
                                flow = InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
                                creds = flow.run_local_server(port=0)
                            # Persist the token next to the credentials file for later runs.
                            with open(token_path, 'w') as token:
                                token.write(creds.to_json())
                        utils.log("SYSTEM", "✅ Authenticated via OAuth Client ID.")
                except Exception as e:
                    utils.log("SYSTEM", f"⚠️ OAuth check failed: {e}")

        vertexai.init(project=config.GCP_PROJECT, location=config.GCP_LOCATION, credentials=creds)
        utils.log("SYSTEM", f"✅ Vertex AI initialized (Project: {config.GCP_PROJECT})")

        # Override with Vertex Image Model if available
        try:
            model_image = VertexImageModel.from_pretrained("imagen-3.0-generate-001")
            img_source = "Vertex AI"
        except Exception as e:
            # Keep whatever image model (possibly None) was chosen above.
            utils.log("SYSTEM", f"⚠️ Vertex image model unavailable: {e}")

    utils.log("SYSTEM", f"Image Generation Provider: {img_source}")