Files
Claude_tests/image_query.py
Johannes c24df544b8 1.0
2026-04-10 15:49:53 +02:00

198 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import anthropic
import base64
import json
import sys
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# ── Settings ───────────────────────────────────────────────────────────────────
# Model ID sent with every request; estimate_cost() also looks it up in MODEL_NAME_MAP.
MODEL = "claude-opus-4-6"
# Passed as max_tokens on every request.
MAX_TOKENS = 4096
# The four sampling settings below are forwarded to the API only when they are
# not None (see query_claude); None leaves the choice to the API.
TEMPERATURE = None # 0.0-1.0, None = model default (1.0)
TOP_P = None # 0.0-1.0, nucleus sampling, None = off
TOP_K = None # int, top-k sampling, None = off
STOP_SEQUENCES = None # e.g. ["END", "STOP"], None = off
# Translation table: API model ID -> display name. The display name is what
# estimate_cost() looks up among the "model" values of claude_pricing.json.
MODEL_NAME_MAP = {
    # Opus 4.x
    "claude-opus-4-6": "Claude Opus 4.6",
    "claude-opus-4-5": "Claude Opus 4.5",
    "claude-opus-4-1": "Claude Opus 4.1",
    "claude-opus-4-0": "Claude Opus 4",
    # Sonnet
    "claude-sonnet-4-6": "Claude Sonnet 4.6",
    "claude-sonnet-4-5": "Claude Sonnet 4.5",
    "claude-sonnet-4-0": "Claude Sonnet 4",
    "claude-sonnet-3-7": "Claude Sonnet 3.7",
    # Haiku
    "claude-haiku-4-5": "Claude Haiku 4.5",
    "claude-haiku-3-5": "Claude Haiku 3.5",
    "claude-haiku-3": "Claude Haiku 3",
    # Opus 3
    "claude-opus-3": "Claude Opus 3",
}
# ── Paths ──────────────────────────────────────────────────────────────────────
# Text log that write_log() appends one record to per request.
LOG_FILE = "log.txt"
# JSON price list read by load_pricing().
PRICING_FILE = Path("claude_pricing.json")
# Directories that bare (directory-less) image / output file names are placed under.
PICS_DIR = Path("pics")
OUTPUT_DIR = Path("output")

# Lower-case file extension (with leading dot) -> media type sent with the image.
MEDIA_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
}
# ── Pricing ────────────────────────────────────────────────────────────────────
def load_pricing() -> dict:
    """Read PRICING_FILE and index its entries by model name.

    Returns:
        A dict mapping each entry's "model" value to the whole entry. The
        file is parsed as JSON and iterated, so it is expected to contain a
        sequence of objects that each have a "model" key.
    """
    with open(PRICING_FILE, encoding="utf-8") as pricing_fh:
        records = json.load(pricing_fh)

    by_model = {}
    for record in records:
        by_model[record["model"]] = record
    return by_model
def estimate_cost(model_id: str, input_tokens: int, output_tokens: int) -> str:
    """Return a human-readable cost estimate for one request.

    The estimate only feeds a log line, so this function never raises for
    pricing-data problems: a missing or malformed pricing file, or an entry
    without usable rates, yields an "n/a (...)" string instead. Letting such
    errors propagate would abort the caller after the API request has
    already completed.

    Args:
        model_id: API model ID, translated through MODEL_NAME_MAP.
        input_tokens: Number of input tokens billed.
        output_tokens: Number of output tokens billed.

    Returns:
        "$total (in: $x out: $y)" with six decimals, or an "n/a (...)"
        explanation when no estimate can be computed.
    """
    friendly_name = MODEL_NAME_MAP.get(model_id)
    if not friendly_name:
        return "n/a (model not in pricing file)"

    try:
        pricing = load_pricing()
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # OSError: file missing/unreadable; ValueError: invalid JSON or bad
        # encoding; KeyError/TypeError: entries not shaped as expected.
        return f"n/a (pricing file unavailable: {exc})"

    rates = pricing.get(friendly_name)
    if rates is None:
        return "n/a (model not in pricing file)"

    try:
        # Rates are expressed per million tokens.
        input_cost = (input_tokens / 1_000_000) * rates["base_input_tokens_per_mtok"]
        output_cost = (output_tokens / 1_000_000) * rates["output_tokens_per_mtok"]
        total = input_cost + output_cost
        return f"${total:.6f} (in: ${input_cost:.6f} out: ${output_cost:.6f})"
    except (KeyError, TypeError, ValueError) as exc:
        return f"n/a (incomplete pricing entry: {exc})"
# ── Path helpers ───────────────────────────────────────────────────────────────
def resolve_image_path(image_path: str) -> Path:
    """Map a user-supplied image name to the path that will be opened.

    A name without any directory component is placed under PICS_DIR; a path
    that has a directory part, or is absolute, is returned unchanged.

    Note: pathlib drops a leading "./", so "./cat.png" counts as a bare name.
    """
    candidate = Path(image_path)
    has_directory = candidate.parent != Path(".")
    if has_directory or candidate.is_absolute():
        return candidate
    return PICS_DIR / candidate
def resolve_output_path(output_file: str) -> Path:
    """Map a user-supplied output name to the path that will be written.

    A name without any directory component is placed under OUTPUT_DIR; a path
    that has a directory part, or is absolute, is returned unchanged.
    """
    target = Path(output_file)
    is_bare_name = target.parent == Path(".") and not target.is_absolute()
    return OUTPUT_DIR / target if is_bare_name else target
def get_media_type(image_path: str) -> str:
    """Return the media type registered in MEDIA_TYPES for the file's extension.

    The extension is compared case-insensitively.

    Raises:
        ValueError: If the extension has no entry in MEDIA_TYPES.
    """
    suffix = Path(image_path).suffix.lower()
    mime = MEDIA_TYPES.get(suffix)
    if mime:
        return mime
    raise ValueError(f"Unsupported image format: {suffix}. Use jpg, png, gif, or webp.")
def load_image_b64(image_path: str) -> str:
    """Read a file in binary mode and return its contents as base64 text.

    The output uses the standard base64 alphabet.
    """
    with open(image_path, "rb") as image_fh:
        raw = image_fh.read()
    encoded = base64.b64encode(raw)
    # Base64 output is pure ASCII, so this decodes to the same text as UTF-8.
    return encoded.decode("ascii")
# ── Logging ────────────────────────────────────────────────────────────────────
def write_log(
    image_path: str,
    prompt: str,
    time_sent: datetime,
    time_received: datetime,
    usage,
    output_file: str,
):
    """Append one record describing a completed request to LOG_FILE.

    Args:
        image_path: Image path as it should appear in the log.
        prompt: Prompt text that was sent.
        time_sent: Timestamp taken just before the request.
        time_received: Timestamp taken just after the response.
        usage: Object exposing ``input_tokens`` and ``output_tokens``.
        output_file: Output path as it should appear in the log.
    """

    def shown(value, fallback):
        # Unset (None) settings are logged under a readable placeholder.
        return fallback if value is None else value

    elapsed = (time_received - time_sent).total_seconds()
    cost = estimate_cost(MODEL, usage.input_tokens, usage.output_tokens)

    settings = " ".join(
        [
            f"temp={shown(TEMPERATURE, 'default')}",
            f"max_tokens={MAX_TOKENS}",
            f"top_p={shown(TOP_P, 'off')}",
            f"top_k={shown(TOP_K, 'off')}",
            f"stop_seq={shown(STOP_SEQUENCES, 'off')}",
        ]
    )

    parts = [
        f"[{time_sent.strftime('%Y-%m-%d %H:%M:%S')}]\n",
        f" model: {MODEL}\n",
        f" settings: {settings}\n",
        f" image: {image_path}\n",
        f" prompt: {prompt}\n",
        f" output: {output_file}\n",
        # [:-3] trims microseconds down to milliseconds.
        f" sent: {time_sent.strftime('%H:%M:%S.%f')[:-3]}\n",
        f" received: {time_received.strftime('%H:%M:%S.%f')[:-3]}\n",
        f" duration: {elapsed:.2f}s\n",
        f" tokens in: {usage.input_tokens}\n",
        f" tokens out: {usage.output_tokens}\n",
        f" est. cost: {cost}\n",
    ]
    record = "".join(parts)

    # A trailing blank line separates consecutive records.
    with open(LOG_FILE, "a", encoding="utf-8") as log_fh:
        log_fh.write(record + "\n")
# ── Main query ─────────────────────────────────────────────────────────────────
def query_claude(image_path: str, prompt: str, output_file: str = "response.txt") -> str:
    """Send one image plus a text prompt to the model and return the reply text.

    The image path is resolved with resolve_image_path() and the output path
    with resolve_output_path(); the latter is only recorded in the log here,
    this function does not write the output file.

    Args:
        image_path: Image file name or path.
        prompt: Text sent alongside the image.
        output_file: Output file name or path, recorded in the log entry.

    Returns:
        The concatenation of every text block in the response, or an empty
        string when the response contains no text block.

    Raises:
        ValueError: If the image extension is not supported.
        OSError: If the image file cannot be read.
    """
    client = anthropic.Anthropic()
    image_path = str(resolve_image_path(image_path))
    output_file = str(resolve_output_path(output_file))
    media_type = get_media_type(image_path)
    image_data = load_image_b64(image_path)

    # Forward only the sampling settings that were explicitly configured, so
    # the API applies its own defaults for the rest.
    configured = {
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "stop_sequences": STOP_SEQUENCES,
    }
    optional_params = {key: value for key, value in configured.items() if value is not None}

    time_sent = datetime.now()
    response = client.messages.create(
        model=MODEL,
        max_tokens=MAX_TOKENS,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data,
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        **optional_params,
    )
    time_received = datetime.now()

    # Logging is bookkeeping: a failure here (unwritable log file, broken
    # pricing data) must not discard a reply that has already been received.
    try:
        write_log(image_path, prompt, time_sent, time_received, response.usage, output_file)
    except (OSError, ValueError, KeyError, TypeError) as exc:
        print(f"Warning: could not write log entry: {exc}", file=sys.stderr)

    # Join all text blocks instead of taking only the first one; with no text
    # block this yields "" rather than leaking a bare StopIteration.
    return "".join(block.text for block in response.content if block.type == "text")
def main():
    """Command-line entry point: query the model about an image and save the reply.

    Usage: python image_query.py <image_path> "<prompt>" [output_file]

    Everything that can be checked locally (image present, extension
    supported, output directory available) is checked before the API request
    is made, so a local mistake cannot waste a completed request.
    Exits with status 1 on a usage error or a failed local check.
    """
    if len(sys.argv) < 3:
        print("Usage: python image_query.py <image_path> \"<prompt>\" [output_file]")
        sys.exit(1)

    image_path = sys.argv[1]
    prompt = sys.argv[2]
    output_file = sys.argv[3] if len(sys.argv) > 3 else "response.txt"

    resolved_image = resolve_image_path(image_path)
    # is_file() rather than exists(): a directory is not a usable image.
    if not resolved_image.is_file():
        print(f"Error: image file '{resolved_image}' not found.")
        sys.exit(1)

    # Reject unsupported extensions with a message instead of a traceback.
    try:
        get_media_type(str(resolved_image))
    except ValueError as exc:
        print(f"Error: {exc}")
        sys.exit(1)

    # Create the destination directory up front; otherwise opening the output
    # file fails after the request when e.g. "output/" does not exist yet.
    resolved_output = resolve_output_path(output_file)
    resolved_output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Querying Claude about '{resolved_image}'...")
    response_text = query_claude(image_path, prompt, output_file)

    print("\n--- Response ---")
    print(response_text)

    with open(resolved_output, "w", encoding="utf-8") as f:
        f.write(response_text)
    print(f"\nSaved to {resolved_output}")


if __name__ == "__main__":
    main()