198 lines
7.7 KiB
Python
198 lines
7.7 KiB
Python
import anthropic
|
||
import base64
|
||
import json
|
||
import sys
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from dotenv import load_dotenv
|
||
|
||
# Load variables from a local .env file into the process environment before any
# client is created (presumably the Anthropic API key — confirm).
load_dotenv()
|
||
|
||
# ── Settings ───────────────────────────────────────────────────────────────────
# Request parameters for the Messages API call in query_claude. Any sampling
# option left as None is omitted from the request, so the API default applies.

MODEL = "claude-opus-4-6"      # API model ID; MODEL_NAME_MAP maps it for pricing
MAX_TOKENS = 4096              # upper bound on generated tokens per reply

# Optional sampling controls (None = not sent at all):
TEMPERATURE = None             # float, 0.0–1.0; model default (1.0) when None
TOP_P = None                   # float, 0.0–1.0, nucleus sampling
TOP_K = None                   # int, top-k sampling
STOP_SEQUENCES = None          # list of strings, e.g. ["END", "STOP"]
# NOTE(review): some newer models reject requests that set both TEMPERATURE
# and TOP_P — confirm against the model docs before enabling both.
# Maps API model IDs → friendly names used as the "model" field in
# claude_pricing.json. Both the short aliases and the dated snapshot IDs are
# listed so that pricing works whichever form MODEL is set to.
MODEL_NAME_MAP = {
    # Claude 4.x aliases
    "claude-opus-4-6": "Claude Opus 4.6",
    "claude-opus-4-5": "Claude Opus 4.5",
    "claude-opus-4-1": "Claude Opus 4.1",
    "claude-opus-4-0": "Claude Opus 4",
    "claude-sonnet-4-6": "Claude Sonnet 4.6",
    "claude-sonnet-4-5": "Claude Sonnet 4.5",
    "claude-sonnet-4-0": "Claude Sonnet 4",
    "claude-haiku-4-5": "Claude Haiku 4.5",
    # Claude 4.x dated snapshots
    "claude-opus-4-5-20251101": "Claude Opus 4.5",
    "claude-opus-4-1-20250805": "Claude Opus 4.1",
    "claude-opus-4-20250514": "Claude Opus 4",
    "claude-sonnet-4-5-20250929": "Claude Sonnet 4.5",
    "claude-sonnet-4-20250514": "Claude Sonnet 4",
    "claude-haiku-4-5-20251001": "Claude Haiku 4.5",
    # Claude 3.x: the API puts the generation number first
    # (e.g. "claude-3-7-sonnet-..."), not last.
    "claude-3-7-sonnet-latest": "Claude Sonnet 3.7",
    "claude-3-7-sonnet-20250219": "Claude Sonnet 3.7",
    "claude-3-5-haiku-latest": "Claude Haiku 3.5",
    "claude-3-5-haiku-20241022": "Claude Haiku 3.5",
    "claude-3-haiku-20240307": "Claude Haiku 3",
    "claude-3-opus-latest": "Claude Opus 3",
    "claude-3-opus-20240229": "Claude Opus 3",
    # Legacy keys kept for backward compatibility.
    # NOTE(review): these do not follow the API's naming for 3.x models.
    "claude-sonnet-3-7": "Claude Sonnet 3.7",
    "claude-haiku-3-5": "Claude Haiku 3.5",
    "claude-haiku-3": "Claude Haiku 3",
    "claude-opus-3": "Claude Opus 3",
}
# ── Paths ──────────────────────────────────────────────────────────────────────
# Every path here is relative, so it resolves against the current working
# directory at run time rather than the script's own directory.

LOG_FILE = "log.txt"                         # appended to by write_log
PRICING_FILE = Path("claude_pricing.json")   # read by load_pricing
# Default homes for bare file names (see resolve_image_path / resolve_output_path).
PICS_DIR, OUTPUT_DIR = Path("pics"), Path("output")
# Lower-case file extension (leading dot included) → MIME type sent with the
# base64 image. Both JPEG spellings share one entry value.
MEDIA_TYPES = {
    **dict.fromkeys((".jpg", ".jpeg"), "image/jpeg"),
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
}
# ── Pricing ────────────────────────────────────────────────────────────────────

def load_pricing() -> dict:
    """Read PRICING_FILE and index its records by their "model" field.

    The file is presumably a JSON array of objects that each carry a "model"
    key. If two records share a model name, the later one wins.
    """
    with open(PRICING_FILE, encoding="utf-8") as handle:
        records = json.load(handle)

    by_model = {}
    for record in records:
        by_model[record["model"]] = record
    return by_model
def estimate_cost(model_id: str, input_tokens: int, output_tokens: int) -> str:
    """Return a human-readable cost estimate for one request.

    *model_id* is translated through MODEL_NAME_MAP into the friendly name
    used as the key in the pricing file. The per-million-token input and
    output rates are then applied. If a dated snapshot ID (ending in
    ``-YYYYMMDD``) is not mapped directly, the undated alias is tried.

    This is best-effort logging output. A missing or unreadable pricing file,
    an unknown model, or an incomplete pricing entry all yield an
    ``"n/a (...)"`` string instead of raising. A pricing problem therefore
    never discards an API response that has already been paid for.
    """
    import re

    try:
        pricing = load_pricing()
    except (OSError, json.JSONDecodeError, KeyError, TypeError) as err:
        return f"n/a (pricing file unavailable: {err})"

    friendly_name = MODEL_NAME_MAP.get(model_id)
    if friendly_name is None:
        # Snapshot IDs carry a trailing -YYYYMMDD; retry with the alias form.
        friendly_name = MODEL_NAME_MAP.get(re.sub(r"-\d{8}$", "", model_id))

    if not friendly_name or friendly_name not in pricing:
        return "n/a (model not in pricing file)"

    rates = pricing[friendly_name]
    try:
        input_rate = rates["base_input_tokens_per_mtok"]
        output_rate = rates["output_tokens_per_mtok"]
    except KeyError as err:
        return f"n/a (pricing entry missing {err})"

    input_cost = (input_tokens / 1_000_000) * input_rate
    output_cost = (output_tokens / 1_000_000) * output_rate
    total = input_cost + output_cost
    return f"${total:.6f} (in: ${input_cost:.6f} out: ${output_cost:.6f})"
# ── Path helpers ───────────────────────────────────────────────────────────────

def _under_default_dir(raw_path: str, default_dir: Path) -> Path:
    """Place a bare file name inside *default_dir*; leave other paths alone.

    A "bare" name is a relative path with no directory part, e.g. ``"cat.png"``.
    pathlib collapses a leading ``./``, so ``"./cat.png"`` also counts as bare.
    A path with a directory component, or an absolute path, is returned
    unchanged.
    """
    p = Path(raw_path)
    if p.parent == Path(".") and not p.is_absolute():
        return default_dir / p
    return p


def resolve_image_path(image_path: str) -> Path:
    """Resolve *image_path*, placing bare file names under PICS_DIR."""
    return _under_default_dir(image_path, PICS_DIR)


def resolve_output_path(output_file: str) -> Path:
    """Resolve *output_file*, placing bare file names under OUTPUT_DIR."""
    return _under_default_dir(output_file, OUTPUT_DIR)
def get_media_type(image_path: str) -> str:
    """Return the MIME type for *image_path*, based on its file extension.

    The extension match is case-insensitive (``.PNG`` works). Raises
    ValueError when the extension is not a key of MEDIA_TYPES.
    """
    suffix = Path(image_path).suffix.lower()
    mime = MEDIA_TYPES.get(suffix)
    if mime:
        return mime
    raise ValueError(f"Unsupported image format: {suffix}. Use jpg, png, gif, or webp.")
def load_image_b64(image_path: str) -> str:
    """Read *image_path* as raw bytes and return them as standard base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# ── Logging ────────────────────────────────────────────────────────────────────

def write_log(image_path: str, prompt: str, time_sent: datetime, time_received: datetime, usage, output_file: str):
    """Append a human-readable record of one API call to LOG_FILE.

    Args:
        image_path: Image that was sent (as resolved by the caller).
        prompt: Text prompt sent with the image.
        time_sent: Timestamp taken just before the request.
        time_received: Timestamp taken just after the response; the
            difference from time_sent is logged as the duration.
        usage: Response usage object; only ``input_tokens`` and
            ``output_tokens`` are read.
        output_file: Path the caller intends to save the reply to.

    Each record ends with a blank line so consecutive entries stay separate.
    """
    elapsed = (time_received - time_sent).total_seconds()
    cost_str = estimate_cost(MODEL, usage.input_tokens, usage.output_tokens)
    clock_fmt = "%H:%M:%S.%f"

    def shown(value, placeholder):
        # Settings left as None are logged as a placeholder word instead.
        return placeholder if value is None else value

    settings = " ".join((
        f"temp={shown(TEMPERATURE, 'default')}",
        f"max_tokens={MAX_TOKENS}",
        f"top_p={shown(TOP_P, 'off')}",
        f"top_k={shown(TOP_K, 'off')}",
        f"stop_seq={shown(STOP_SEQUENCES, 'off')}",
    ))

    lines = [
        f"[{time_sent.strftime('%Y-%m-%d %H:%M:%S')}]",
        f" model: {MODEL}",
        f" settings: {settings}",
        f" image: {image_path}",
        f" prompt: {prompt}",
        f" output: {output_file}",
        # %f is microseconds; dropping the last 3 digits leaves milliseconds.
        f" sent: {time_sent.strftime(clock_fmt)[:-3]}",
        f" received: {time_received.strftime(clock_fmt)[:-3]}",
        f" duration: {elapsed:.2f}s",
        f" tokens in: {usage.input_tokens}",
        f" tokens out: {usage.output_tokens}",
        f" est. cost: {cost_str}",
    ]
    record = "\n".join(lines) + "\n"

    with open(LOG_FILE, "a", encoding="utf-8") as log:
        log.write(record + "\n")
# ── Main query ─────────────────────────────────────────────────────────────────

def query_claude(image_path: str, prompt: str, output_file: str = "response.txt") -> str:
    """Send one image plus a text prompt to Claude and return the reply text.

    Args:
        image_path: Image to send; a bare file name is looked up in PICS_DIR.
        prompt: Text placed after the image in the single user message.
        output_file: Output path recorded in the log. This function does not
            write it; the caller saves the reply.

    Returns:
        The concatenated text of every text block in the response.

    Raises:
        ValueError: If the image extension is not a supported format.
        RuntimeError: If the response contains no text block at all.
        Errors raised by the anthropic client are not caught here.

    Each completed request is logged via write_log before the text is
    returned. A warning is printed to stderr if the reply was cut off at
    MAX_TOKENS.
    """
    client = anthropic.Anthropic()

    image_path = str(resolve_image_path(image_path))
    output_file = str(resolve_output_path(output_file))
    # Validate the format before reading the file or spending an API call.
    media_type = get_media_type(image_path)
    image_data = load_image_b64(image_path)

    # Only forward sampling options that were explicitly configured.
    optional_params = {
        name: value
        for name, value in (
            ("temperature", TEMPERATURE),
            ("top_p", TOP_P),
            ("top_k", TOP_K),
            ("stop_sequences", STOP_SEQUENCES),
        )
        if value is not None
    }

    content = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": image_data,
            },
        },
        {"type": "text", "text": prompt},
    ]

    time_sent = datetime.now()
    response = client.messages.create(
        model=MODEL,
        max_tokens=MAX_TOKENS,
        messages=[{"role": "user", "content": content}],
        **optional_params,
    )
    time_received = datetime.now()

    write_log(image_path, prompt, time_sent, time_received, response.usage, output_file)

    if response.stop_reason == "max_tokens":
        # Partial text is still returned, but don't let truncation go unnoticed.
        print(f"Warning: response truncated at MAX_TOKENS={MAX_TOKENS}.", file=sys.stderr)

    # Join every text block; the old next(...) kept only the first one and
    # leaked a bare StopIteration when there were none.
    texts = [block.text for block in response.content if block.type == "text"]
    if not texts:
        raise RuntimeError(
            f"Response contained no text content (stop_reason={response.stop_reason!r})."
        )
    return "".join(texts)
def main():
    """Command-line entry point: ``<image_path> "<prompt>" [output_file]``.

    The image is validated first. Claude is then queried, and the reply is
    printed and saved to the resolved output path; bare names go under
    OUTPUT_DIR, and output_file defaults to "response.txt". Exits with
    status 1 on a usage error or when the image file does not exist.
    """
    if len(sys.argv) < 3:
        print("Usage: python image_query.py <image_path> \"<prompt>\" [output_file]")
        sys.exit(1)

    image_path = sys.argv[1]
    prompt = sys.argv[2]
    output_file = sys.argv[3] if len(sys.argv) > 3 else "response.txt"

    resolved_image = resolve_image_path(image_path)
    # is_file() rather than exists(): a directory would pass exists() and then
    # fail later when opened for reading.
    if not resolved_image.is_file():
        print(f"Error: image file '{resolved_image}' not found.")
        sys.exit(1)

    resolved_output = resolve_output_path(output_file)
    # Create the output directory before the (billed) API call. Otherwise a
    # missing OUTPUT_DIR raises FileNotFoundError after the reply arrives, and
    # the reply is lost.
    resolved_output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Querying Claude about '{resolved_image}'...")
    response_text = query_claude(image_path, prompt, output_file)

    print("\n--- Response ---")
    print(response_text)

    with open(resolved_output, "w", encoding="utf-8") as f:
        f.write(response_text)

    print(f"\nSaved to {resolved_output}")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|