"""Model bootstrap for the Gemini / Vertex AI backends.

Selects which Gemini models serve the "logic", "writer" and "artist" roles
(asking a cheap bootstrapper model to choose, with an on-disk cache and
hard-coded fallbacks), wraps them in quota-aware retry proxies, and sets up an
optional image-generation model via the Gemini API or Vertex AI.
"""
import os
import json
import time
import warnings
import google.generativeai as genai
import config
from . import utils

# Suppress Vertex AI warnings
warnings.filterwarnings("ignore", category=UserWarning, module="vertexai")

try:
    import vertexai
    from vertexai.preview.vision_models import ImageGenerationModel as VertexImageModel
    HAS_VERTEX = True
except ImportError:
    HAS_VERTEX = False

try:
    from google.auth.transport.requests import Request
    from google.oauth2.credentials import Credentials
    from google_auth_oauthlib.flow import InstalledAppFlow
    HAS_OAUTH = True
except ImportError:
    HAS_OAUTH = False

# Module-level singletons populated by init_models().
model_logic = None
model_writer = None
model_artist = None
model_image = None
logic_model_name = "models/gemini-1.5-pro"
writer_model_name = "models/gemini-1.5-flash"
artist_model_name = "models/gemini-1.5-flash"

# How long a cached model selection is considered fresh (24 hours).
_CACHE_TTL_SECONDS = 86400
# Roles every model selection must provide for init_models() to work.
_REQUIRED_ROLES = ("logic", "writer", "artist")


def _cache_path():
    """Return the path of the on-disk model selection cache."""
    return os.path.join(config.DATA_DIR, "model_cache.json")


def _has_all_roles(selection):
    """Return True if *selection* is a dict with a non-empty entry per role."""
    return isinstance(selection, dict) and all(
        selection.get(role) for role in _REQUIRED_ROLES
    )


def _cost_missing(role_data):
    """Return True if *role_data* carries no usable 'estimated_cost' value."""
    if not isinstance(role_data, dict):
        # Legacy string entries (or missing/None entries) have no cost info.
        return True
    return role_data.get('estimated_cost', 'N/A') == 'N/A'


class ResilientModel:
    """Proxy around ``genai.GenerativeModel`` that retries on quota errors.

    The wrapped model can be swapped in place via :meth:`update`, which lets
    ``init_models(force=True)`` rotate to another model while callers keep
    holding the same proxy object.
    """

    MAX_RETRIES = 3
    BASE_DELAY_SECONDS = 5

    def __init__(self, name, safety_settings, role):
        self.name = name
        self.safety_settings = safety_settings
        self.role = role
        self.model = genai.GenerativeModel(name, safety_settings=safety_settings)

    def update(self, name):
        """Point this proxy at the model called *name*, keeping safety settings."""
        self.name = name
        self.model = genai.GenerativeModel(name, safety_settings=self.safety_settings)

    def generate_content(self, *args, **kwargs):
        """Forward to the wrapped model, retrying quota errors with backoff.

        Quota errors (message contains "429" or "quota") are retried up to
        ``MAX_RETRIES`` times with delays of 5s, 10s and 20s. Before the first
        retry the module's model selection is refreshed so a different model
        may be used. Any other error, or a quota error after the last retry,
        is re-raised unchanged.
        """
        retries = 0
        while True:
            try:
                return self.model.generate_content(*args, **kwargs)
            except Exception as e:
                message = str(e)
                is_quota = "429" in message or "quota" in message.lower()
                if not is_quota or retries >= self.MAX_RETRIES:
                    raise

                delay = self.BASE_DELAY_SECONDS * (2 ** retries)
                utils.log("SYSTEM", f"⚠️ Quota error on {self.role} ({self.name}). Retrying in {delay}s...")
                time.sleep(delay)

                # On first retry, attempt to re-optimize/rotate models
                if retries == 0:
                    utils.log("SYSTEM", "Attempting to re-optimize models to find alternative...")
                    try:
                        # Note: init_models calls .update() on this instance
                        init_models(force=True)
                    except Exception as reinit_err:
                        # A failed re-selection must not mask the quota error
                        # or abort the retry loop; keep using the current model.
                        utils.log("SYSTEM", f"⚠️ Model re-optimization failed: {reinit_err}")
                retries += 1


def get_optimal_model(base_type="pro"):
    """Return the best available model name containing *base_type*.

    Stable models are preferred over "latest" aliases, which are preferred
    over experimental/beta ones. Falls back to ``models/gemini-1.5-<base_type>``
    when the API cannot be queried or lists no matching model.
    """
    default = f"models/gemini-1.5-{base_type}"

    def score(n):
        # Prioritize stable models (higher quotas) over experimental/beta ones
        if "exp" in n or "beta" in n:
            return 0
        if "latest" in n:
            return 50
        return 100

    try:
        candidates = [
            m.name
            for m in genai.list_models()
            if 'generateContent' in m.supported_generation_methods
            and base_type in m.name
        ]
    except Exception:
        # Listing is best-effort (network/auth may be unavailable).
        return default

    # max() returns the first of equally-scored names, matching a stable
    # descending sort followed by [0].
    return max(candidates, key=score) if candidates else default


def get_default_models():
    """Return the hard-coded model selection used when everything else fails."""
    return {
        "logic": {"model": "models/gemini-1.5-pro", "reason": "Fallback: Default Pro model selected.", "estimated_cost": "$3.50/1M"},
        "writer": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected.", "estimated_cost": "$0.075/1M"},
        "artist": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected.", "estimated_cost": "$0.075/1M"},
        "ranking": []
    }


def _build_selection_prompt(available):
    """Build the prompt asking the bootstrapper to pick a model per role."""
    return f"""
    ROLE: AI Model Architect
    TASK: Select the optimal Gemini models for specific application roles.

    AVAILABLE_MODELS:
    {json.dumps(available)}

    PRICING_CONTEXT (USD per 1M tokens):
    - Flash Models (e.g. gemini-1.5-flash): ~$0.075 Input / $0.30 Output. (Very Cheap)
    - Pro Models (e.g. gemini-1.5-pro): ~$3.50 Input / $10.50 Output. (Expensive)

    CRITERIA:
    - LOGIC: Needs complex reasoning, JSON adherence, and instruction following. (Prefer Pro/1.5).
    - WRITER: Needs creativity, prose quality, and speed. (Prefer Flash/1.5 for speed, or Pro for quality).
    - ARTIST: Needs visual prompt understanding.

    CONSTRAINTS:
    - Avoid 'experimental' unless no stable version exists.
    - Prioritize 'latest' or stable versions.

    OUTPUT_FORMAT (JSON):
    {{
      "logic": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
      "writer": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
      "artist": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
      "ranking": [ {{ "model": "string", "reason": "string", "estimated_cost": "string" }} ]
    }}
    """


def select_best_models(force_refresh=False):
    """
    Uses a safe bootstrapper model to analyze available models and pick the best ones.
    Caches the result for 24 hours.

    Resolution order:
      1. A fresh cache entry (unless *force_refresh*).
      2. A new selection generated by the bootstrapper model (then cached).
      3. A stale cache entry, provided it names all three roles.
      4. The defaults from get_default_models() (also written to the cache).

    Returns a dict with "logic", "writer" and "artist" entries (plus,
    typically, "ranking"). Never raises for API or cache failures.
    """
    cache_path = _cache_path()
    cached_models = None

    # 1. Check Cache
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                cached = json.load(f)
            cached_models = cached.get('models', {})
            # Check if within 24 hours (86400 seconds)
            is_fresh = time.time() - cached.get('timestamp', 0) < _CACHE_TTL_SECONDS
            if not force_refresh and is_fresh:
                # Validate format (must be dicts with reasons, not just strings)
                logic_entry = cached_models.get('logic')
                if isinstance(logic_entry, dict) and 'reason' in logic_entry:
                    utils.log("SYSTEM", "Using cached AI model selection (valid for 24h).")
                    return cached_models
        except Exception as cache_err:
            utils.log("SYSTEM", f"Cache read failed: {cache_err}. Refreshing models.")

    # 2. Ask the bootstrapper model for a fresh selection
    try:
        utils.log("SYSTEM", "Refreshing AI model list from API...")
        available = [
            m.name
            for m in genai.list_models()
            if 'generateContent' in m.supported_generation_methods
            and 'gemini' in m.name.lower()
        ]

        bootstrapper = get_optimal_model("flash")
        utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")

        model = genai.GenerativeModel(bootstrapper)
        prompt = _build_selection_prompt(available)

        try:
            response = model.generate_content(prompt)
            selection = json.loads(utils.clean_json(response.text))
            if not _has_all_roles(selection):
                # init_models() indexes all three roles; reject partial answers
                # here so the stale-cache/default fallbacks kick in instead.
                raise ValueError("selection is missing one of: logic, writer, artist")
        except Exception as gen_err:
            utils.log("SYSTEM", f"Model selection generation failed (Safety/Format): {gen_err}")
            raise

        os.makedirs(config.DATA_DIR, exist_ok=True)
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump({"timestamp": int(time.time()), "models": selection, "available_at_time": available}, f, indent=2)
        return selection
    except Exception as e:
        utils.log("SYSTEM", f"AI Model Selection failed: {e}.")

        # 3. Fallback to Stale Cache if available (Better than heuristics)
        # Relaxed check: any cached selection naming all three roles is usable,
        # regardless of age or per-role format.
        if _has_all_roles(cached_models):
            utils.log("SYSTEM", "⚠️ Using stale cached models due to API failure.")
            return cached_models

        utils.log("SYSTEM", "Falling back to heuristics.")
        fallback = get_default_models()

        # Save fallback to cache if file doesn't exist OR if we couldn't load it (corrupt/None)
        # This ensures we have a valid file on disk for the web UI to read.
        try:
            # The data directory may not exist yet on a first run.
            os.makedirs(config.DATA_DIR, exist_ok=True)
            with open(cache_path, 'w', encoding='utf-8') as f:
                json.dump({"timestamp": int(time.time()), "models": fallback, "error": str(e)}, f, indent=2)
        except OSError as write_err:
            # Best-effort: the in-memory fallback is still returned.
            utils.log("SYSTEM", f"Could not write fallback model cache: {write_err}")
        return fallback


def _cache_is_fresh(cache_path):
    """Return True if the cache file exists, parses, and is under 24h old."""
    if not os.path.exists(cache_path):
        return False
    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            cached = json.load(f)
        return time.time() - cached.get('timestamp', 0) < _CACHE_TTL_SECONDS
    except (OSError, ValueError, AttributeError, TypeError):
        # Unreadable, corrupt, or unexpectedly shaped cache counts as stale.
        return False


def init_models(force=False):
    """Initialise (or re-initialise) the module-level model objects.

    Configures the Gemini API key, picks a model per role via
    select_best_models() (honouring the ``config.MODEL_*_HINT`` overrides),
    creates or updates the ResilientModel proxies in place, and sets up an
    image model from the Gemini API and/or Vertex AI when available.

    Args:
        force: When False, do nothing if models are already initialised and
            skip credential validation while the cache is fresh. When True,
            always re-select models, bypassing the fresh-cache shortcut.
    """
    global model_logic, model_writer, model_artist, model_image, logic_model_name, writer_model_name, artist_model_name
    if model_logic is not None and not force:
        return
    genai.configure(api_key=config.API_KEY)

    # Check cache to skip frequent validation
    cache_path = _cache_path()
    skip_validation = not force and _cache_is_fresh(cache_path)

    if not skip_validation:
        # Validate Gemini API Key
        utils.log("SYSTEM", "Validating credentials...")
        try:
            # Fetch only the first model: list_models() is paginated, so
            # materialising it with page_size=1 would issue one request per
            # model.
            next(iter(genai.list_models(page_size=1)), None)
            utils.log("SYSTEM", "✅ Gemini API Key is valid.")
        except Exception as e:
            # Check if we have a cache file we can rely on before exiting
            if os.path.exists(cache_path):
                utils.log("SYSTEM", f"⚠️ API check failed ({e}), but cache exists. Attempting to use cached models.")
            else:
                utils.log("SYSTEM", f"⚠️ API check failed ({e}). No cache found. Attempting to initialize with defaults.")

    utils.log("SYSTEM", "Selecting optimal models via AI...")
    selected_models = select_best_models(force_refresh=force)

    # Check for missing costs and force refresh if needed
    if not force and any(_cost_missing(selected_models.get(role)) for role in _REQUIRED_ROLES):
        utils.log("SYSTEM", "⚠️ Missing cost info in cached models. Forcing refresh.")
        # The forced call skips this check, so this recurses at most once.
        return init_models(force=True)

    def get_model_details(role_data):
        # Entries are normally dicts; legacy caches may hold bare model names.
        if isinstance(role_data, dict):
            return role_data.get('model'), role_data.get('estimated_cost', 'N/A')
        return role_data, 'N/A'

    logic_name, logic_cost = get_model_details(selected_models['logic'])
    writer_name, writer_cost = get_model_details(selected_models['writer'])
    artist_name, artist_cost = get_model_details(selected_models['artist'])

    # Explicit hints from config override the automatic selection.
    logic_name = logic_model_name = logic_name if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
    writer_name = writer_model_name = writer_name if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
    artist_name = artist_model_name = artist_name if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT

    utils.log("SYSTEM", f"Models: Logic={logic_name} ({logic_cost}) | Writer={writer_name} ({writer_cost}) | Artist={artist_name}")

    # Update pricing in utils
    utils.update_pricing(logic_name, logic_cost)
    utils.update_pricing(writer_name, writer_cost)
    utils.update_pricing(artist_name, artist_cost)

    # Initialize or Update Resilient Models
    if model_logic is None:
        model_logic = ResilientModel(logic_name, utils.SAFETY_SETTINGS, "Logic")
        model_writer = ResilientModel(writer_name, utils.SAFETY_SETTINGS, "Writer")
        model_artist = ResilientModel(artist_name, utils.SAFETY_SETTINGS, "Artist")
    else:
        # If models already exist (re-init), update them in place
        model_logic.update(logic_name)
        model_writer.update(writer_name)
        model_artist.update(artist_name)

    # Initialize Image Model (Default to None)
    model_image = None
    if hasattr(genai, 'ImageGenerationModel'):
        try:
            model_image = genai.ImageGenerationModel("imagen-3.0-generate-001")
        except Exception as img_err:
            # Image generation is optional; continue without it.
            utils.log("SYSTEM", f"Gemini image model unavailable: {img_err}")
    img_source = "Gemini API" if model_image else "None"

    # Auto-detect GCP Project from credentials if not set (Fix for Image Model)
    if HAS_VERTEX and not config.GCP_PROJECT and config.GOOGLE_CREDS and os.path.exists(config.GOOGLE_CREDS):
        try:
            with open(config.GOOGLE_CREDS, 'r', encoding='utf-8') as f:
                cdata = json.load(f)
            # Check common OAuth structures
            for k in ['installed', 'web']:
                if k in cdata and 'project_id' in cdata[k]:
                    config.GCP_PROJECT = cdata[k]['project_id']
                    utils.log("SYSTEM", f"Auto-detected GCP Project ID: {config.GCP_PROJECT}")
                    break
        except Exception as detect_err:
            # Detection is a convenience only; Vertex AI is simply skipped.
            utils.log("SYSTEM", f"GCP project auto-detection failed: {detect_err}")

    if HAS_VERTEX and config.GCP_PROJECT:
        creds = None
        # Handle OAuth Client ID (credentials.json) if provided instead of Service Account
        if HAS_OAUTH:
            gac = config.GOOGLE_CREDS  # Use persistent config, not volatile env var
            if gac and os.path.exists(gac):
                try:
                    with open(gac, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if 'installed' in data or 'web' in data:
                        # It's an OAuth Client ID. Unset env var to avoid library crash.
                        os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)

                        token_path = os.path.join(os.path.dirname(os.path.abspath(gac)), 'token.json')
                        SCOPES = ['https://www.googleapis.com/auth/cloud-platform']

                        if os.path.exists(token_path):
                            creds = Credentials.from_authorized_user_file(token_path, SCOPES)

                        if not creds or not creds.valid:
                            if creds and creds.expired and creds.refresh_token:
                                try:
                                    creds.refresh(Request())
                                except Exception:
                                    utils.log("SYSTEM", "Token refresh failed. Re-authenticating...")
                                    flow = InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
                                    creds = flow.run_local_server(port=0)
                            else:
                                utils.log("SYSTEM", "OAuth Client ID detected. Launching browser to authenticate...")
                                flow = InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
                                creds = flow.run_local_server(port=0)
                            # Persist the new/refreshed token next to the client secrets file.
                            with open(token_path, 'w', encoding='utf-8') as token:
                                token.write(creds.to_json())

                        utils.log("SYSTEM", "✅ Authenticated via OAuth Client ID.")
                except Exception as e:
                    utils.log("SYSTEM", f"⚠️ OAuth check failed: {e}")

        vertexai.init(project=config.GCP_PROJECT, location=config.GCP_LOCATION, credentials=creds)
        utils.log("SYSTEM", f"✅ Vertex AI initialized (Project: {config.GCP_PROJECT})")

        # Override with Vertex Image Model if available
        try:
            model_image = VertexImageModel.from_pretrained("imagen-3.0-generate-001")
            img_source = "Vertex AI"
        except Exception as vertex_err:
            # Keep whatever image model (if any) the Gemini API provided.
            utils.log("SYSTEM", f"Vertex AI image model unavailable: {vertex_err}")

    utils.log("SYSTEM", f"Image Generation Provider: {img_source}")