More improvements.
This commit is contained in:
113
modules/ai.py
113
modules/ai.py
@@ -28,6 +28,45 @@ model_logic = None
|
||||
model_writer = None
|
||||
model_artist = None
|
||||
model_image = None
|
||||
logic_model_name = "models/gemini-1.5-pro"
|
||||
writer_model_name = "models/gemini-1.5-flash"
|
||||
artist_model_name = "models/gemini-1.5-flash"
|
||||
|
||||
class ResilientModel:
|
||||
def __init__(self, name, safety_settings, role):
|
||||
self.name = name
|
||||
self.safety_settings = safety_settings
|
||||
self.role = role
|
||||
self.model = genai.GenerativeModel(name, safety_settings=safety_settings)
|
||||
|
||||
def update(self, name):
|
||||
self.name = name
|
||||
self.model = genai.GenerativeModel(name, safety_settings=self.safety_settings)
|
||||
|
||||
def generate_content(self, *args, **kwargs):
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
base_delay = 5
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self.model.generate_content(*args, **kwargs)
|
||||
except Exception as e:
|
||||
is_quota = "429" in str(e) or "quota" in str(e).lower()
|
||||
if is_quota and retries < max_retries:
|
||||
delay = base_delay * (2 ** retries)
|
||||
utils.log("SYSTEM", f"⚠️ Quota error on {self.role} ({self.name}). Retrying in {delay}s...")
|
||||
time.sleep(delay)
|
||||
|
||||
# On first retry, attempt to re-optimize/rotate models
|
||||
if retries == 0:
|
||||
utils.log("SYSTEM", "Attempting to re-optimize models to find alternative...")
|
||||
init_models(force=True)
|
||||
# Note: init_models calls .update() on this instance
|
||||
|
||||
retries += 1
|
||||
continue
|
||||
raise e
|
||||
|
||||
def get_optimal_model(base_type="pro"):
|
||||
try:
|
||||
@@ -44,9 +83,9 @@ def get_optimal_model(base_type="pro"):
|
||||
|
||||
def get_default_models():
|
||||
return {
|
||||
"logic": {"model": "models/gemini-1.5-pro", "reason": "Fallback: Default Pro model selected."},
|
||||
"writer": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected."},
|
||||
"artist": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected."},
|
||||
"logic": {"model": "models/gemini-1.5-pro", "reason": "Fallback: Default Pro model selected.", "estimated_cost": "$3.50/1M"},
|
||||
"writer": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected.", "estimated_cost": "$0.075/1M"},
|
||||
"artist": {"model": "models/gemini-1.5-flash", "reason": "Fallback: Default Flash model selected.", "estimated_cost": "$0.075/1M"},
|
||||
"ranking": []
|
||||
}
|
||||
|
||||
@@ -78,11 +117,8 @@ def select_best_models(force_refresh=False):
|
||||
utils.log("SYSTEM", "Refreshing AI model list from API...")
|
||||
models = [m.name for m in genai.list_models() if 'generateContent' in m.supported_generation_methods and 'gemini' in m.name.lower()]
|
||||
|
||||
bootstrapper = "models/gemini-1.5-flash"
|
||||
if bootstrapper not in models:
|
||||
candidates = [m for m in models if 'flash' in m]
|
||||
bootstrapper = candidates[0] if candidates else "models/gemini-pro"
|
||||
utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")
|
||||
bootstrapper = get_optimal_model("flash")
|
||||
utils.log("SYSTEM", f"Bootstrapping model selection with: {bootstrapper}")
|
||||
|
||||
model = genai.GenerativeModel(bootstrapper)
|
||||
prompt = f"""
|
||||
@@ -92,6 +128,10 @@ def select_best_models(force_refresh=False):
|
||||
AVAILABLE_MODELS:
|
||||
{json.dumps(models)}
|
||||
|
||||
PRICING_CONTEXT (USD per 1M tokens):
|
||||
- Flash Models (e.g. gemini-1.5-flash): ~$0.075 Input / $0.30 Output. (Very Cheap)
|
||||
- Pro Models (e.g. gemini-1.5-pro): ~$3.50 Input / $10.50 Output. (Expensive)
|
||||
|
||||
CRITERIA:
|
||||
- LOGIC: Needs complex reasoning, JSON adherence, and instruction following. (Prefer Pro/1.5).
|
||||
- WRITER: Needs creativity, prose quality, and speed. (Prefer Flash/1.5 for speed, or Pro for quality).
|
||||
@@ -103,10 +143,10 @@ def select_best_models(force_refresh=False):
|
||||
|
||||
OUTPUT_FORMAT (JSON):
|
||||
{{
|
||||
"logic": {{ "model": "string", "reason": "string" }},
|
||||
"writer": {{ "model": "string", "reason": "string" }},
|
||||
"artist": {{ "model": "string", "reason": "string" }},
|
||||
"ranking": [ {{ "model": "string", "reason": "string" }} ]
|
||||
"logic": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
|
||||
"writer": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
|
||||
"artist": {{ "model": "string", "reason": "string", "estimated_cost": "$X.XX Input / $X.XX Output" }},
|
||||
"ranking": [ {{ "model": "string", "reason": "string", "estimated_cost": "string" }} ]
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -142,7 +182,7 @@ def select_best_models(force_refresh=False):
|
||||
return fallback
|
||||
|
||||
def init_models(force=False):
|
||||
global model_logic, model_writer, model_artist, model_image
|
||||
global model_logic, model_writer, model_artist, model_image, logic_model_name, writer_model_name, artist_model_name
|
||||
if model_logic and not force: return
|
||||
genai.configure(api_key=config.API_KEY)
|
||||
|
||||
@@ -171,18 +211,45 @@ def init_models(force=False):
|
||||
utils.log("SYSTEM", "Selecting optimal models via AI...")
|
||||
selected_models = select_best_models(force_refresh=force)
|
||||
|
||||
def get_model_name(role_data):
|
||||
if isinstance(role_data, dict): return role_data.get('model')
|
||||
return role_data
|
||||
# Check for missing costs and force refresh if needed
|
||||
if not force:
|
||||
missing_costs = False
|
||||
for role in ['logic', 'writer', 'artist']:
|
||||
if 'estimated_cost' not in selected_models.get(role, {}) or selected_models[role].get('estimated_cost') == 'N/A':
|
||||
missing_costs = True
|
||||
if missing_costs:
|
||||
utils.log("SYSTEM", "⚠️ Missing cost info in cached models. Forcing refresh.")
|
||||
return init_models(force=True)
|
||||
|
||||
logic_name = get_model_name(selected_models['logic']) if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
|
||||
writer_name = get_model_name(selected_models['writer']) if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
|
||||
artist_name = get_model_name(selected_models['artist']) if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT
|
||||
utils.log("SYSTEM", f"Models: Logic={logic_name} | Writer={writer_name} | Artist={artist_name}")
|
||||
def get_model_details(role_data):
|
||||
if isinstance(role_data, dict): return role_data.get('model'), role_data.get('estimated_cost', 'N/A')
|
||||
return role_data, 'N/A'
|
||||
|
||||
logic_name, logic_cost = get_model_details(selected_models['logic'])
|
||||
writer_name, writer_cost = get_model_details(selected_models['writer'])
|
||||
artist_name, artist_cost = get_model_details(selected_models['artist'])
|
||||
|
||||
logic_name = logic_model_name = logic_name if config.MODEL_LOGIC_HINT == "AUTO" else config.MODEL_LOGIC_HINT
|
||||
writer_name = writer_model_name = writer_name if config.MODEL_WRITER_HINT == "AUTO" else config.MODEL_WRITER_HINT
|
||||
artist_name = artist_model_name = artist_name if config.MODEL_ARTIST_HINT == "AUTO" else config.MODEL_ARTIST_HINT
|
||||
|
||||
model_logic = genai.GenerativeModel(logic_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
model_writer = genai.GenerativeModel(writer_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
model_artist = genai.GenerativeModel(artist_name, safety_settings=utils.SAFETY_SETTINGS)
|
||||
utils.log("SYSTEM", f"Models: Logic={logic_name} ({logic_cost}) | Writer={writer_name} ({writer_cost}) | Artist={artist_name}")
|
||||
|
||||
# Update pricing in utils
|
||||
utils.update_pricing(logic_name, logic_cost)
|
||||
utils.update_pricing(writer_name, writer_cost)
|
||||
utils.update_pricing(artist_name, artist_cost)
|
||||
|
||||
# Initialize or Update Resilient Models
|
||||
if model_logic is None:
|
||||
model_logic = ResilientModel(logic_name, utils.SAFETY_SETTINGS, "Logic")
|
||||
model_writer = ResilientModel(writer_name, utils.SAFETY_SETTINGS, "Writer")
|
||||
model_artist = ResilientModel(artist_name, utils.SAFETY_SETTINGS, "Artist")
|
||||
else:
|
||||
# If models already exist (re-init), update them in place
|
||||
model_logic.update(logic_name)
|
||||
model_writer.update(writer_name)
|
||||
model_artist.update(artist_name)
|
||||
|
||||
# Initialize Image Model (Default to None)
|
||||
model_image = None
|
||||
|
||||
Reference in New Issue
Block a user