# 90 lines · 3.3 KiB · Python
import json
import os
import re
import time
import warnings

import google.generativeai as genai

from core import utils
|
|
|
|
# Silence UserWarnings attributed to the vertexai package. The ``module``
# argument is a regex matched against the start of the module name the
# warning is attributed to, so vertexai submodules are covered as well.
# NOTE(review): attribution follows the stacklevel vertexai passes to warn();
# warnings raised with stacklevel > 1 may be attributed to the calling module
# instead and would not be suppressed here — verify if the noise persists.
warnings.filterwarnings(action="ignore", category=UserWarning, module="vertexai")
|
|
|
|
# Optional dependency: Vertex AI image generation. HAS_VERTEX is True only
# when both the package and its preview vision-model class import cleanly.
try:
    import vertexai
    from vertexai.preview.vision_models import ImageGenerationModel as VertexImageModel
except ImportError:
    HAS_VERTEX = False
else:
    HAS_VERTEX = True
|
|
|
|
# Optional dependency: Google OAuth helpers. HAS_OAUTH is True only when all
# three imports succeed; otherwise none of the flow can be used.
try:
    from google.auth.transport.requests import Request
    from google.oauth2.credentials import Credentials
    from google_auth_oauthlib.flow import InstalledAppFlow
except ImportError:
    HAS_OAUTH = False
else:
    HAS_OAUTH = True
|
|
|
|
# Module-level model handles. They start as None and nothing in this file
# assigns them — presumably they are filled in by setup code elsewhere (TODO confirm).
model_logic = model_writer = model_artist = model_image = None

# Default model identifiers, one per role.
logic_model_name = "models/gemini-1.5-flash"
writer_model_name = "models/gemini-1.5-flash"
artist_model_name = "models/gemini-1.5-flash"
# Best available Pro for critical rewrites (prefer free/exp)
pro_model_name = "models/gemini-2.0-pro-exp"

# Image-generation selection; unset at import time.
# NOTE: image_model_source holds the *string* "None", not the None singleton.
image_model_name = None
image_model_source = "None"
|
|
|
|
|
|
class ResilientModel:
    """Wrap ``genai.GenerativeModel`` with payload warnings, a default timeout and retries.

    Constructor args:
        name: Model identifier passed to ``genai.GenerativeModel``.
        safety_settings: Forwarded to ``genai.GenerativeModel``; reused by ``update``.
        role: Free-form label used only in log messages.
    """

    # Estimated token count above which a payload warning is logged.
    _TOKEN_WARN_LIMIT = 30_000

    # Timeout in seconds for all generate_content calls (prevents indefinite hangs)
    _GENERATION_TIMEOUT = 180

    # Retry policy: at most _MAX_RETRIES extra attempts, sleeping
    # _BASE_DELAY * 2**attempt seconds before each one (5s, 10s, 20s by default).
    _MAX_RETRIES = 3
    _BASE_DELAY = 5

    # Status codes treated as transient. In messages they must appear as a
    # whole number, so text such as "1500 tokens" is not mistaken for a 500.
    _RETRYABLE_STATUS_CODES = frozenset({429, 500, 503, 504})
    _RETRYABLE_STATUS_RE = re.compile(r"\b(?:429|500|503|504)\b")

    def __init__(self, name, safety_settings, role):
        self.name = name
        self.safety_settings = safety_settings
        self.role = role
        self.model = genai.GenerativeModel(name, safety_settings=safety_settings)

    def update(self, name):
        """Switch to model *name*, keeping the existing safety settings and role."""
        self.name = name
        self.model = genai.GenerativeModel(name, safety_settings=self.safety_settings)

    @staticmethod
    def _estimate_payload_tokens(payload):
        """Return a rough token estimate for a str or list payload; 0 for anything else.

        Non-string list items (images, parts objects, ...) contribute nothing.
        """
        if isinstance(payload, str):
            return utils.estimate_tokens(payload)
        if isinstance(payload, list):
            return sum(utils.estimate_tokens(p) for p in payload if isinstance(p, str))
        return 0

    @classmethod
    def _classify_error(cls, exc):
        """Return ``(is_timeout, is_retryable)`` for the exception *exc*.

        Classification is heuristic. It looks at the lower-cased message and,
        when present, at an integer ``code`` attribute. Exceptions such as
        google-api-core's ``GoogleAPICallError`` expose the HTTP status this way.
        """
        err_str = str(exc).lower()
        is_timeout = "timeout" in err_str or "deadline" in err_str or "timed out" in err_str
        status = getattr(exc, "code", None)
        # Some libraries expose ``code`` as a method (e.g. gRPC); only trust ints.
        has_retryable_status = isinstance(status, int) and status in cls._RETRYABLE_STATUS_CODES
        is_retryable = (
            is_timeout
            or has_retryable_status
            or cls._RETRYABLE_STATUS_RE.search(err_str) is not None
            or "quota" in err_str
            or "internal error" in err_str
        )
        return is_timeout, is_retryable

    def generate_content(self, *args, **kwargs):
        """Call ``self.model.generate_content`` with a default timeout and retries.

        Positional and keyword arguments are forwarded unchanged, with one
        exception: ``request_options``. If it is a dict, omitted, or falsy, a
        fresh dict is built. Its ``timeout`` defaults to ``_GENERATION_TIMEOUT``
        unless the caller set one. The caller's dict is never mutated. Other
        ``request_options`` values are forwarded as-is.

        Returns:
            Whatever the underlying ``generate_content`` returns.

        Raises:
            Re-raises the original exception when it is not classified as
            transient. A transient failure (timeout, quota/429, 500/503/504,
            "internal error") is also re-raised once ``_MAX_RETRIES`` retries
            have been used.
        """
        # Warn on oversized payloads; the prompt may be positional or ``contents=``.
        payload = args[0] if args else kwargs.get("contents")
        est = self._estimate_payload_tokens(payload)
        if est > self._TOKEN_WARN_LIMIT:
            utils.log("SYSTEM", f"⚠️ Payload warning: ~{est:,} tokens for {self.role} ({self.name}). Consider reducing context.")

        # Inject timeout into request_options without overwriting caller-supplied values.
        rq_opts = kwargs.pop("request_options", None) or {}
        if isinstance(rq_opts, dict):
            # Copy first: adding "timeout" must not leak into a dict the caller owns or reuses.
            rq_opts = dict(rq_opts)
            rq_opts.setdefault("timeout", self._GENERATION_TIMEOUT)

        retries = 0
        while True:
            try:
                return self.model.generate_content(*args, **kwargs, request_options=rq_opts)
            except Exception as e:
                is_timeout, is_retryable = self._classify_error(e)
                if not is_retryable or retries >= self._MAX_RETRIES:
                    raise
                delay = self._BASE_DELAY * (2 ** retries)
                utils.log("SYSTEM", f"⚠️ {'Timeout' if is_timeout else 'API error'} on {self.role} ({self.name}). Retrying in {delay}s... ({retries + 1}/{self._MAX_RETRIES})")
                time.sleep(delay)
                retries += 1
|