Files
bookapp/ai/models.py

90 lines
3.3 KiB
Python

import os
import json
import time
import warnings
import google.generativeai as genai
from core import utils
# Suppress Vertex AI warnings
warnings.filterwarnings("ignore", category=UserWarning, module="vertexai")
try:
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel as VertexImageModel
HAS_VERTEX = True
except ImportError:
HAS_VERTEX = False
try:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
HAS_OAUTH = True
except ImportError:
HAS_OAUTH = False
# Per-role model handles — start as None and are presumably populated later
# (elsewhere in this module) with ResilientModel wrappers; confirm in setup code.
model_logic = None
model_writer = None
model_artist = None
model_image = None
# Default Gemini model identifiers for each pipeline role.
logic_model_name = "models/gemini-1.5-flash"
writer_model_name = "models/gemini-1.5-flash"
artist_model_name = "models/gemini-1.5-flash"
pro_model_name = "models/gemini-2.0-pro-exp" # Best available Pro for critical rewrites (prefer free/exp)
# Image model is resolved at runtime; source tracks which backend was chosen.
image_model_name = None
image_model_source = "None"
class ResilientModel:
    """Wrapper around ``genai.GenerativeModel`` that survives transient API failures.

    Adds three behaviors on top of the raw client:
      * warns when an estimated prompt payload exceeds a safe token limit,
      * injects a default timeout into every ``generate_content`` call,
      * retries timeouts, quota errors (429) and server errors (500/503/504)
        with exponential backoff before giving up.
    """

    # Warn when an estimated prompt exceeds this many tokens.
    _TOKEN_WARN_LIMIT = 30_000
    # Timeout in seconds for all generate_content calls (prevents indefinite hangs).
    _GENERATION_TIMEOUT = 180

    def __init__(self, name, safety_settings, role):
        """
        Args:
            name: Gemini model identifier (e.g. "models/gemini-1.5-flash").
            safety_settings: safety configuration forwarded to GenerativeModel.
            role: human-readable label used only in log messages.
        """
        self.name = name
        self.safety_settings = safety_settings
        self.role = role
        self.model = genai.GenerativeModel(name, safety_settings=safety_settings)

    def update(self, name):
        """Swap in a different model name, keeping the existing safety settings."""
        self.name = name
        self.model = genai.GenerativeModel(name, safety_settings=self.safety_settings)

    def generate_content(self, *args, **kwargs):
        """Call the underlying model with timeout injection and transient-error retries.

        Retries up to 3 times with exponential backoff (5s, 10s, 20s) on
        timeout/deadline, 429/quota, and 500/503/504/"internal error" failures.
        Any other exception — or exhaustion of the retry budget — is re-raised.

        Returns:
            The response object from ``GenerativeModel.generate_content``.
        """
        # Estimate payload size and warn if it exceeds the safe limit.
        if args:
            payload = args[0]
            if isinstance(payload, str):
                est = utils.estimate_tokens(payload)
            elif isinstance(payload, list):
                # Only string parts are countable; non-text parts (images, etc.) count as 0.
                est = sum(utils.estimate_tokens(p) if isinstance(p, str) else 0 for p in payload)
            else:
                est = 0
            if est > self._TOKEN_WARN_LIMIT:
                utils.log("SYSTEM", f"⚠️ Payload warning: ~{est:,} tokens for {self.role} ({self.name}). Consider reducing context.")

        retries = 0
        max_retries = 3
        base_delay = 5
        # Inject timeout into request_options without overwriting caller-supplied values.
        rq_opts = kwargs.pop("request_options", {}) or {}
        if isinstance(rq_opts, dict):
            # Non-dict request_options (e.g. a RequestOptions object) are forwarded untouched.
            rq_opts.setdefault("timeout", self._GENERATION_TIMEOUT)
        while True:
            try:
                return self.model.generate_content(*args, **kwargs, request_options=rq_opts)
            except Exception as e:
                # Classification is string-based because the SDK raises several
                # exception types; match on the message instead.
                err_str = str(e).lower()
                is_timeout = "timeout" in err_str or "deadline" in err_str or "timed out" in err_str
                is_retryable = (
                    is_timeout
                    or "429" in err_str
                    or "quota" in err_str
                    or "500" in err_str
                    or "503" in err_str
                    or "504" in err_str
                    or "internal error" in err_str
                )
                if is_retryable and retries < max_retries:
                    delay = base_delay * (2 ** retries)
                    utils.log("SYSTEM", f"⚠️ {'Timeout' if is_timeout else 'API error'} on {self.role} ({self.name}). Retrying in {delay}s... ({retries + 1}/{max_retries})")
                    time.sleep(delay)
                    retries += 1
                    continue
                # Bare `raise` re-raises the active exception with its original
                # traceback intact (idiomatic; `raise e` adds a redundant frame).
                raise