Adding files.
This commit is contained in:
215
modules/ai.py
Normal file
215
modules/ai.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import warnings
|
||||
import google.generativeai as genai
|
||||
import config
|
||||
from . import utils
|
||||
|
||||
# Suppress Vertex AI warnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="vertexai")
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.preview.vision_models import ImageGenerationModel as VertexImageModel
|
||||
HAS_VERTEX = True
|
||||
except ImportError:
|
||||
HAS_VERTEX = False
|
||||
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
HAS_OAUTH = True
|
||||
except ImportError:
|
||||
HAS_OAUTH = False
|
||||
|
||||
model_logic = None
|
||||
model_writer = None
|
||||
model_artist = None
|
||||
model_image = None
|
||||
|
||||
def get_optimal_model(base_type="pro"):
|
||||
try:
|
||||
models = [m for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
|
||||
candidates = [m.name for m in models if base_type in m.name]
|
||||
if not candidates: return f"models/gemini-1.5-{base_type}"
|
||||
def score(n):
|
||||
# Prioritize stable models (higher quotas) over experimental/beta ones
|
||||
if "exp" in n or "beta" in n: return 0
|
||||
if "latest" in n: return 50
|
||||
return 100
|
||||
return sorted(candidates, key=score, reverse=True)[0]
|
||||
except: return f"models/gemini-1.5-{base_type}"
|
||||
|
||||
def get_default_models():
|
||||
return {
|
||||
"logic": {"model": "models/gemini-1.5-pro", "reason": "Fallback: Default Pro model selected."},
|
||||
"writer": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected."},
|
||||
"artist": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected."},
|
||||
"ranking": []
|
||||
}
|
||||
|
||||
def select_best_models(force_refresh=False):
|
||||
"""
|
||||
Uses a safe bootstrapper model to analyze available models and pick the best ones.
|
||||
Caches the result for 24 hours.
|
||||
"""
|
||||
cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
|
||||
cached_models = None
|
||||
|
||||
# 1. Check Cache
|
||||
if os.path.exists(cache_path):
|
||||
try:
|
||||
with open(cache_path, 'r') as f:
|
||||
cached = json.load(f)
|
||||
cached_models = cached.get('models', {})
|
||||
# Check if within 24 hours (86400 seconds)
|
||||
if not force_refresh and time.time() - cached.get('timestamp', 0) < 86400:
|
||||
models = cached_models
|
||||
# Validate format (must be dicts with reasons, not just strings)
|
||||
if isinstance(models.get('logic'), dict) and 'reason' in models['logic']:
|
||||
utils.log("SYSTEM", "Using cached AI model selection (valid for 24h).")
|
||||
return models
|
||||
except Exception as e:
|
||||
utils.log("SYSTEM", f"Cache read failed: {e}. Refreshing models.")
|
||||
|
||||
try:
|
||||
utils.log("SYSTEM", "Refreshing AI model list from API...")
|
||||
models = [m.name for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
|
||||
|
||||
bootstrapper = "models/gemini-1.5-flash"
|
||||
if bootstrapper not in models:
|
||||
candidates = [m for m in models if 'flash' in m]
|
||||
bootstrapper = candidates[0] if candidates else "models/gemini-pro"
|
||||
utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")
|
||||
|
||||
model = genai.GenerativeModel(bootstrapper)
|
||||
prompt = f"Analyze this list of available Google Gemini models:\n{json.dumps(models)}\n\nSelect the best model for each of these three roles based on these criteria:\n- Most recent version with best features and ability.\n- Beta versions are okay, but avoid 'experimental' if a stable beta/prod version exists.\n- Consider quota efficiency (Flash is cheaper/faster, Pro is smarter).\n\nROLES:\n1. LOGIC: For complex reasoning, JSON structuring, and plot planning.\n2. WRITER: For creative fiction writing, prose generation, and speed.\n3. ARTIST: For generating visual art prompts and design instructions.\n\nAlso provide a 'ranking' list of ALL models analyzed, ordered from best/most useful to worst/least useful, with a short reason.\n\nReturn JSON: {{ 'logic': {{ 'model': 'model_name', 'reason': 'reasoning' }}, 'writer': {{ 'model': 'model_name', 'reason': 'reasoning' }}, 'artist': {{ 'model': 'model_name', 'reason': 'reasoning' }}, 'ranking': [ {{ 'model': 'model_name', 'reason': 'reasoning' }} ] }}"
|
||||
|
||||
response = model.generate_content(prompt)
|
||||
selection = json.loads(utils.clean_json(response.text))
|
||||
|
||||
if not os.path.exists(config.DATA_DIR): os.makedirs(config.DATA_DIR)
|
||||
with open(cache_path, 'w') as f:
|
||||
json.dump({"timestamp": int(time.time()), "models": selection, "available_at_time": models}, f, indent=2)
|
||||
return selection
|
||||
except Exception as e:
|
||||
utils.log("SYSTEM", f"AI Model Selection failed: {e}.")
|
||||
|
||||
# 3. Fallback to Stale Cache if available (Better than heuristics)
|
||||
# Relaxed check: If we successfully loaded ANY JSON from the cache, use it.
|
||||
if cached_models:
|
||||
utils.log("SYSTEM", "⚠️ Using stale cached models due to API failure.")
|
||||
return cached_models
|
||||
|
||||
utils.log("SYSTEM", "Falling back to heuristics.")
|
||||
fallback = get_default_models()
|
||||
|
||||
# Save fallback to cache if file doesn't exist OR if we couldn't load it (corrupt/None)
|
||||
# This ensures we have a valid file on disk for the web UI to read.
|
||||
try:
|
||||
with open(cache_path, 'w') as f:
|
||||
json.dump({"timestamp": int(time.time()), "models": fallback, "error": str(e)}, f, indent=2)
|
||||
except: pass
|
||||
return fallback
|
||||
|
||||
def init_models(force=False):
|
||||
global model_logic, model_writer, model_artist, model_image
|
||||
if model_logic and not force: return
|
||||
genai.configure(api_key=config.API_KEY)
|
||||
|
||||
# Check cache to skip frequent validation
|
||||
cache_path = os.path.join(config.DATA_DIR, "model_cache.json")
|
||||
skip_validation = False
|
||||
if not force and os.path.exists(cache_path):
|
||||
try:
|
||||
with open(cache_path, 'r') as f: cached = json.load(f)
|
||||
if time.time() - cached.get('timestamp', 0) < 86400: skip_validation = True
|
||||
except: pass
|
||||
|
||||
if not skip_validation:
|
||||
# Validate Gemini API Key
|
||||
utils.log("SYSTEM", "Validating credentials...")
|
||||
try:
|
||||
list(genai.list_models(page_size=1))
|
||||
utils.log("SYSTEM", "✅ Gemini API Key is valid.")
|
||||
except Exception as e:
|
||||
# Check if we have a cache file we can rely on before exiting
|
||||
if os.path.exists(cache_path):
|
||||
utils.log("SYSTEM", f"⚠️ API check failed ({e}), but cache exists. Attempting to use cached models.")
|
||||
else:
|
||||
utils.log("SYSTEM", f"⚠️ API check failed ({e}). No cache found. Attempting to initialize with defaults.")
|
||||
|
||||
utils.log("SYSTEM", "Selecting optimal models via AI...")
|
||||
selected_models = select_best_models(force_refresh=force)
|
||||
|
||||
def get_model_name(role_data):
|
||||
if isinstance(role_data, dict): return role_data.get('model')
|
||||
return role_data
|
||||
|
||||
logic_name = get_model_name(selected_models['logic']) if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
|
||||
writer_name = get_model_name(selected_models['writer']) if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
|
||||
artist_name = get_model_name(selected_models['artist']) if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT
|
||||
utils.log("SYSTEM", f"Models: Logic={logic_name} | Writer={writer_name} | Artist={artist_name}")
|
||||
|
||||
model_logic = genai.GenerativeModel(logic_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
model_writer = genai.GenerativeModel(writer_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
model_artist = genai.GenerativeModel(artist_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
|
||||
# Initialize Image Model (Default to None)
|
||||
model_image = None
|
||||
if hasattr(genai, 'ImageGenerationModel'):
|
||||
try: model_image = genai.ImageGenerationModel("imagen-3.0-generate-001")
|
||||
except: pass
|
||||
|
||||
img_source = "Gemini API" if model_image else "None"
|
||||
|
||||
if HAS_VERTEX and config.GCP_PROJECT:
|
||||
creds = None
|
||||
# Handle OAuth Client ID (credentials.json) if provided instead of Service Account
|
||||
if HAS_OAUTH:
|
||||
gac = config.GOOGLE_CREDS # Use persistent config, not volatile env var
|
||||
if gac and os.path.exists(gac):
|
||||
try:
|
||||
with open(gac, 'r') as f: data = json.load(f)
|
||||
if 'installed' in data or 'web' in data:
|
||||
# It's an OAuth Client ID. Unset env var to avoid library crash.
|
||||
if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
|
||||
del os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
|
||||
|
||||
token_path = os.path.join(os.path.dirname(os.path.abspath(gac)), 'token.json')
|
||||
SCOPES = ['https://www.googleapis.com/auth/cloud-platform']
|
||||
|
||||
if os.path.exists(token_path):
|
||||
creds = Credentials.from_authorized_user_file(token_path, SCOPES)
|
||||
|
||||
if not creds or not creds.valid:
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
except Exception:
|
||||
utils.log("SYSTEM", "Token refresh failed. Re-authenticating...")
|
||||
flow = InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
else:
|
||||
utils.log("SYSTEM", "OAuth Client ID detected. Launching browser to authenticate...")
|
||||
flow = InstalledAppFlow.from_client_secrets_file(gac, SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
with open(token_path, 'w') as token: token.write(creds.to_json())
|
||||
|
||||
utils.log("SYSTEM", "✅ Authenticated via OAuth Client ID.")
|
||||
except Exception as e:
|
||||
utils.log("SYSTEM", f"⚠️ OAuth check failed: {e}")
|
||||
|
||||
vertexai.init(project=config.GCP_PROJECT, location=config.GCP_LOCATION, credentials=creds)
|
||||
utils.log("SYSTEM", f"✅ Vertex AI initialized (Project: {config.GCP_PROJECT})")
|
||||
|
||||
# Override with Vertex Image Model if available
|
||||
try:
|
||||
model_image = VertexImageModel.from_pretrained("imagen-3.0-generate-001")
|
||||
img_source = "Vertex AI"
|
||||
except: pass
|
||||
|
||||
utils.log("SYSTEM", f"Image Generation Provider: {img_source}")
|
||||
Reference in New Issue
Block a user