#!/Users/bradbrown/whisper-ptt-env/bin/python3.13
import http.client
import io
import json
import os
import ssl
import sys
import tempfile
import threading
import urllib.error
import urllib.request

import certifi
import numpy as np
import sounddevice as sd
import soundfile as sf
from pynput import keyboard

# Deepgram credential, read once from the environment. An empty string means
# "not configured"; the __main__ block refuses to start in that case.
DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY", "")
# Sample rate (Hz) of the microphone stream and of the WAV sent to Deepgram.
SAMPLE_RATE = 16_000

# State shared by the keyboard handlers, the transcription worker thread and
# the audio callback. Writes go through `lock`.
recording: bool = False  # True while audio_callback should keep incoming blocks
audio_chunks: list = []  # blocks captured during the current recording
lock = threading.Lock()


def start_recording():
    """Arm audio capture for a new push-to-talk session.

    Under ``lock``, sets ``recording`` and replaces ``audio_chunks`` with a
    fresh empty list so audio_callback starts buffering into it. Then prints
    a status line and plays the macOS "Tink" sound as an audible cue.
    """
    global recording, audio_chunks
    with lock:
        # Tuple assignment binds left to right: recording first, then chunks.
        recording, audio_chunks = True, []
    print("[PTT] Recording...", flush=True)
    # The trailing "&" backgrounds afplay so the shell returns immediately.
    os.system("afplay /System/Library/Sounds/Tink.aiff &")


def _encode_wav(audio):
    """Encode *audio* (sampled at SAMPLE_RATE) as a WAV file and return its bytes.

    Writes to an in-memory buffer instead of a temporary file. soundfile
    needs an explicit ``format`` when the target is a file-like object.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio, SAMPLE_RATE, format="WAV")
    return buffer.getvalue()


def _post_to_deepgram(wav_bytes):
    """POST *wav_bytes* to Deepgram's nova-2 model and return the decoded JSON.

    Raises:
        urllib.error.HTTPError: Deepgram answered with an error status.
        OSError: transport failure (URLError, timeout, TLS error, reset...).
        http.client.HTTPException: malformed or truncated HTTP response.
        ValueError: the response body is not valid UTF-8 JSON.
    """
    request = urllib.request.Request(
        "https://api.deepgram.com/v1/listen?model=nova-2",
        data=wav_bytes,
        headers={
            "Authorization": f"Token {DEEPGRAM_API_KEY}",
            "Content-Type": "audio/wav",
        },
        method="POST",
    )
    # certifi's CA bundle is used for TLS verification.
    ctx = ssl.create_default_context(cafile=certifi.where())
    with urllib.request.urlopen(request, timeout=30, context=ctx) as resp:
        return json.loads(resp.read().decode())


def _extract_transcript(result):
    """Return the stripped first-channel, first-alternative transcript.

    Raises KeyError, IndexError, TypeError or AttributeError when *result*
    does not have the expected response shape.
    """
    return result["results"]["channels"][0]["alternatives"][0]["transcript"].strip()


def _type_text(text):
    """Type *text* into the focused window with synthetic key events."""
    kb = keyboard.Controller()
    words = text.split(" ")
    for i, word in enumerate(words):
        kb.type(word)
        if i < len(words) - 1:
            # NOTE(review): spaces are sent as explicit Key.space presses
            # rather than via kb.type. This is presumably a workaround for how
            # some target apps handle typed spaces; confirm before simplifying.
            kb.press(keyboard.Key.space)
            kb.release(keyboard.Key.space)


def stop_and_transcribe():
    """Stop the current recording, transcribe it with Deepgram and type the text.

    Intended to run on a worker thread, as on_release does, so the network
    round-trip does not block the keyboard listener. Steps:

    1. Disarm audio_callback and snapshot the captured chunks atomically.
    2. Encode the audio as WAV in memory.
    3. Send it to Deepgram.
    4. Type the transcript into the focused window.

    Network, HTTP and response-format failures are printed and the function
    returns normally. Failures while encoding or typing still propagate.
    """
    global recording
    # Stop and snapshot in ONE lock section. If they were separate, a new
    # Ctrl+I in between would let start_recording swap in a fresh empty list,
    # and this call would snapshot that instead of the audio just recorded.
    with lock:
        recording = False
        chunks = list(audio_chunks)
    print("[PTT] Stopped. Transcribing...", flush=True)
    os.system("afplay /System/Library/Sounds/Pop.aiff &")

    if not chunks:
        print("[PTT] No audio captured.", flush=True)
        return

    wav_bytes = _encode_wav(np.concatenate(chunks, axis=0))

    try:
        result = _post_to_deepgram(wav_bytes)
    except urllib.error.HTTPError as e:
        # Must precede the OSError clause: HTTPError is a URLError/OSError.
        # The response body carries Deepgram's explanation (bad key, bad audio...).
        try:
            detail = e.read().decode("utf-8", errors="replace").strip()
        except (OSError, http.client.HTTPException):
            detail = ""
        suffix = f" ({detail})" if detail else ""
        print(f"[PTT] Deepgram request failed: {e}{suffix}", flush=True)
        return
    except (OSError, http.client.HTTPException) as e:
        # Covers URLError plus read timeouts and resets that are not wrapped
        # in URLError once the response has started.
        print(f"[PTT] Deepgram request failed: {e}", flush=True)
        return
    except ValueError as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"[PTT] Failed to parse Deepgram response: {e}", flush=True)
        return

    try:
        text = _extract_transcript(result)
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        print(f"[PTT] Failed to parse Deepgram response: {e}", flush=True)
        return

    if not text:
        print("[PTT] No transcription output.", flush=True)
        return
    print(f"[PTT] Transcribed: {text}", flush=True)
    _type_text(text)


def audio_callback(indata, frames, time, status):
    """sounddevice InputStream callback: keep incoming audio while recording.

    Args:
        indata: the block of captured samples for this callback. It is copied
            before storing because sounddevice only guarantees the buffer for
            the duration of the callback.
        frames: number of frames in ``indata`` (unused).
        time: stream timing information (unused).
        status: sounddevice callback flags; truthy when the stream reports a
            problem such as an input overflow (dropped samples).
    """
    if recording:
        if status:
            # Report stream problems instead of silently producing a
            # recording with gaps. Only done while recording to avoid noise.
            print(f"[PTT] Audio stream status: {status}", file=sys.stderr, flush=True)
        with lock:
            audio_chunks.append(indata.copy())


# Key state maintained by the listener callbacks: whether a Ctrl key is down,
# and whether an "i" press started the current recording and is still held.
ctrl_held = i_held = False


def on_press(key):
    """pynput key-down hook: start recording on Ctrl+I.

    A press of any Ctrl key only sets ``ctrl_held``. A press of "i" while Ctrl
    is held starts a recording, unless an earlier "i" that started one is
    still down. That check stops repeated key-down events for a held key
    (e.g. OS auto-repeat) from restarting it. Keys without a ``char``
    attribute raise AttributeError, which is ignored.

    NOTE(review): text typed by keyboard.Controller in stop_and_transcribe may
    also be delivered to this listener. If Ctrl is still physically held at
    that moment, a typed "i" could start a new recording. Confirm.
    """
    global ctrl_held, i_held
    is_ctrl = key in (keyboard.Key.ctrl, keyboard.Key.ctrl_l, keyboard.Key.ctrl_r)
    if is_ctrl:
        ctrl_held = True
    else:
        try:
            hotkey_down = key.char == "i" and ctrl_held and not i_held
            if hotkey_down:
                i_held = True
                start_recording()
        except AttributeError:
            pass


def on_release(key):
    """pynput key-up hook: finish recording when the hotkey's "i" is released.

    Releasing a Ctrl key only clears ``ctrl_held``; an active recording keeps
    going until the "i" itself comes up. Releasing the "i" that started the
    recording hands off to stop_and_transcribe on a daemon thread, so the
    listener is not blocked while the transcription runs. Keys without a
    ``char`` attribute raise AttributeError, which is ignored.
    """
    global ctrl_held, i_held
    ctrl_keys = (keyboard.Key.ctrl, keyboard.Key.ctrl_l, keyboard.Key.ctrl_r)
    if key in ctrl_keys:
        ctrl_held = False
        return
    try:
        hotkey_released = key.char == "i" and i_held
        if hotkey_released:
            i_held = False
            worker = threading.Thread(target=stop_and_transcribe, daemon=True)
            worker.start()
    except AttributeError:
        pass


if __name__ == "__main__":
    # Refuse to start without credentials; the key is sent with every
    # Deepgram request.
    if not DEEPGRAM_API_KEY:
        print("[PTT] DEEPGRAM_API_KEY not set.", file=sys.stderr)
        sys.exit(1)

    for banner in (
        "[PTT] Starting. Hold Ctrl+I to record, release to transcribe.",
        "[PTT] Ctrl+C to quit.",
    ):
        print(banner, flush=True)

    # The mic stream stays open for the whole session; audio_callback only
    # keeps blocks while `recording` is set, which the key handlers toggle.
    mic = sd.InputStream(
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
        callback=audio_callback,
    )
    # A multi-item `with` enters left to right and exits in reverse,
    # exactly like nesting the two statements.
    with mic, keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
        try:
            listener.join()
        except KeyboardInterrupt:
            print("\n[PTT] Exiting.", flush=True)
