Compare commits

5 Commits

7 changed files with 350 additions and 287 deletions

1
Cargo.lock generated
View File

@@ -300,6 +300,7 @@ dependencies = [
"crossterm_winapi", "crossterm_winapi",
"derive_more", "derive_more",
"document-features", "document-features",
"futures-core",
"mio", "mio",
"parking_lot", "parking_lot",
"rustix 1.0.5", "rustix 1.0.5",

View File

@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
crossterm = { version = "0.29.0" } crossterm = { version = "0.29.0", features = ["event-stream"]}
ratatui = { version = "0.29.0" } ratatui = { version = "0.29.0" }
anyhow = "1.0" anyhow = "1.0"
reqwest = { version = "0.12", features = ["json", "stream"] } reqwest = { version = "0.12", features = ["json", "stream"] }

40
src/busy_lot.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Shared, lock-free counter of active jobs.
/// Cloning is cheap (Arc); dropping a ticket auto-decrements.
#[derive(Clone)]
pub struct BusyLot {
    inner: Arc<AtomicUsize>,
}

/// RAII ticket returned from [`BusyLot::park`].
/// When the ticket is dropped (even on panic) the lot counter goes down.
///
/// `#[must_use]` because discarding the return value of `park()` drops the
/// ticket immediately, so the lot would never appear busy — almost certainly
/// a caller bug.
#[must_use = "dropping the ticket immediately un-parks it"]
pub struct Ticket {
    lot: BusyLot,
}

impl BusyLot {
    /// Creates an empty lot (zero parked tickets).
    pub fn new() -> Self {
        Self {
            inner: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Takes a parking space and returns a ticket.
    /// The counter is decremented again when the returned `Ticket` is dropped.
    pub fn park(&self) -> Ticket {
        // AcqRel pairs with the Acquire load in `is_busy` and the AcqRel
        // decrement in `Ticket::drop`, keeping the count coherent across
        // threads that share this lot via clones of the Arc.
        self.inner.fetch_add(1, Ordering::AcqRel);
        Ticket { lot: self.clone() }
    }

    /// `true` if at least one ticket is still parked.
    pub fn is_busy(&self) -> bool {
        self.inner.load(Ordering::Acquire) != 0
    }
}

impl Default for BusyLot {
    /// Same as [`BusyLot::new`]; provided so the type follows the standard
    /// `new`/`Default` idiom (clippy `new_without_default`) and composes with
    /// `..Default::default()` struct updates.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Ticket {
    fn drop(&mut self) {
        // Runs even during unwinding, so a panicking job still releases
        // its slot in the lot.
        self.lot.inner.fetch_sub(1, Ordering::AcqRel);
    }
}

View File

@@ -1,32 +1,21 @@
use serde::de::{self, Deserializer as DeDeserializer, IntoDeserializer, Visitor}; use serde::de::{self, Deserializer as DeDeserializer, IntoDeserializer, Visitor};
use serde::{Deserialize, Serialize, Serializer, Deserializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{Display, Formatter, Result as FmtResult}; use std::fmt::{Display, Formatter, Result as FmtResult};
#[derive(Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct StreamChunk { pub struct Prompt {
pub message: StreamMessage,
}
#[derive(Deserialize, Debug)]
pub struct StreamMessage {
pub role: String, pub role: String,
pub content: String, pub content: String,
} }
#[derive(Serialize, Deserialize, Debug)] impl<'a> From<Message> for Prompt {
pub struct Prompt<'a> {
pub role: Cow<'a, str>,
pub content: Cow<'a, str>,
}
impl<'a> From<Message> for Prompt<'a> {
fn from(message: Message) -> Self { fn from(message: Message) -> Self {
Prompt { Prompt {
role: Cow::Owned(message.role), role: message.role,
content: Cow::Owned(message.content.to_string()), content: message.content.to_string(),
} }
} }
} }
@@ -41,12 +30,12 @@ pub struct ChatOptions {
} }
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
pub struct ChatRequest<'a> { pub struct ChatRequest {
pub model: &'a str, pub model: String,
pub messages: Vec<Prompt<'a>>, pub messages: Vec<Prompt>,
pub stream: bool, pub stream: bool,
pub format: &'a str, pub format: String,
pub stop: Vec<&'a str>, pub stop: Vec<String>,
pub options: Option<ChatOptions>, pub options: Option<ChatOptions>,
} }

View File

@@ -21,7 +21,7 @@
"type": "function", "type": "function",
"function": { "function": {
"name": "web_search", "name": "web_search",
"description": "Search DuckDuckGo (a web search engine)", "description": "Search the web using DuckDuckGo (a web search engine)",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@@ -1,18 +1,21 @@
use crossterm::terminal;
use ratatui::CompletedFrame;
use tokio::sync::mpsc;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crossterm::event::{
self, DisableMouseCapture, EnableMouseCapture, Event, Event as CEvent, EventStream, KeyCode,
MouseButton,
};
use ratatui::{Frame, Terminal, backend::CrosstermBackend};
use ui::OxiTerminal;
use std::borrow::Cow; use std::borrow::Cow;
use std::pin::Pin;
use std::time::{Duration, Instant};
use chat::{Action, Message}; use chat::{Action, Message};
use clap::Parser; use clap::Parser;
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client;
use crossterm::event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode};
use crossterm::terminal::{
EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode,
};
use ratatui::{Terminal, backend::CrosstermBackend};
mod chat; mod chat;
mod ui; mod ui;
@@ -38,100 +41,115 @@ struct Args {
nerd_stats: bool, nerd_stats: bool,
} }
struct App { pub struct Queues {
pub tx_msg: mpsc::UnboundedSender<Msg>, // worker → UI (already exists)
pub rx_msg: mpsc::UnboundedReceiver<Msg>,
pub tx_cmd: mpsc::UnboundedSender<Cmd>, // UI → worker (NEW)
pub rx_cmd: mpsc::UnboundedReceiver<Cmd>,
}
impl Queues {
pub fn new() -> Self {
let (tx_msg, rx_msg) = mpsc::unbounded_channel();
let (tx_cmd, rx_cmd) = mpsc::unbounded_channel();
Queues {
tx_msg,
rx_msg,
tx_cmd,
rx_cmd,
}
}
}
struct AppState {
args: Args, args: Args,
queues: Queues,
prompt: String, prompt: String,
messages: Vec<Message>, messages: Vec<Message>,
waiting: bool, waiting: bool,
system_prompt: String,
} }
#[tokio::main] impl AppState {
async fn main() -> anyhow::Result<()> { const HEADER_PROMPT: &str = r#"SYSTEM: You are "OxiAI", A personal assistant with access to tools. You answer *only* via valid, UTF-8 JSON."#;
// parse arguments
let args = match Args::try_parse() {
Ok(args) => args,
Err(e) => {
e.print().expect("Error writing clap error");
std::process::exit(0);
}
};
// setup crossterm const TOOLS_LIST: &str = include_str!("data/tools_list.json");
enable_raw_mode()?;
let mut stdout_handle = std::io::stdout();
crossterm::execute!(stdout_handle, EnterAlternateScreen, EnableMouseCapture)?;
let backend = CrosstermBackend::new(stdout_handle);
let mut terminal = Terminal::new(backend)?;
let mut app = App { const RULES_PROMPT: &str = r#"Rules:
args,
prompt: String::new(),
messages: vec![],
waiting: false,
};
let client = Client::new();
let header_prompt = r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#;
let tools_list = include_str!("data/tools_list.json")
.parse::<serde_json::Value>()?
.to_string();
let rules_prompt = r#"Rules:
1. Think silently, Never reveal your chain-of-thought. 1. Think silently, Never reveal your chain-of-thought.
2. To use a tool: {"action":"<tool>","arguments":{...}} 2. To use a tool: {"action":"<tool>","arguments":{...}}
3. To reply directly: {"action":"chat","arguments":{"response":"..."}} 3. To reply directly: {"action":"chat","arguments":{"response":"..."}}
4. If a question is vague, comparative, descriptive, or about ideas rather than specifics: use the web_search tool. 4. If a question is vague, comparative, descriptive, or about ideas rather than specifics: use the web_search tool.
5. If a question clearly names a specific object, animal, person, place: use the wiki_search tool. 5. If a question clearly names a specific entity, place, or period of time: use the wiki_search tool.
6. Base claims strictly on provided data or tool results. If unsure, say so. 6. Base claims strictly on provided data or tool results. If unsure, say so.
7. Check your output; If you reach four consecutive newlines: *stop*"#; 7. Check your output; If you reach four consecutive newlines: *stop*"#;
//let user_info_prompt = r#""#; pub fn default(args: Args) -> AppState {
let system_prompt = format!( AppState {
"{header_prompt}\n args,
{tools_list}\n\n queues: Queues::new(),
{rules_prompt}\n" prompt: String::new(),
); messages: vec![],
waiting: false,
system_prompt: AppState::get_system_prompt(),
}
}
loop { pub fn get_system_prompt() -> String {
terminal.draw(|f| ui::chat_ui(f, &app))?; format!(
"{}\n{}\n\n{}\n",
AppState::HEADER_PROMPT,
AppState::TOOLS_LIST,
AppState::RULES_PROMPT
)
}
if event::poll(Duration::from_millis(100))? { pub fn handle_http_done(
if let Event::Key(key) = event::read()? { &mut self,
match key.code { result: Result<String, reqwest::Error>,
KeyCode::Char(c) => app.prompt.push(c), ) -> anyhow::Result<()> {
Ok(())
}
pub fn handle_input(&mut self, ev: Event) -> anyhow::Result<Option<Cmd>> {
match ev {
Event::FocusGained => { /* do nothing */ }
Event::FocusLost => { /* do nothing */ }
Event::Key(key_event) => {
match key_event.code {
KeyCode::Char(c) => self.prompt.push(c),
KeyCode::Backspace => { KeyCode::Backspace => {
app.prompt.pop(); let _ = self.prompt.pop();
} }
KeyCode::Enter => { KeyCode::Enter => {
//TODO: refactor to a parser function to take the contents of the app.prompt vec and do fancy stuff with it (like commands) //TODO: refactor to a parser function to take the contents of the app.prompt vec and do fancy stuff with it (like commands)
let message_args = args_builder! { let message_args = args_builder! {
"response" => app.prompt.clone(), "response" => self.prompt.clone(),
}; };
app.prompt.clear(); self.prompt.clear();
app.messages.push(chat::Message::new( self.messages.push(chat::Message::new(
chat::MessageRoles::User, chat::MessageRoles::User,
chat::Action::Chat, chat::Action::Chat,
message_args, message_args,
)); ));
let mut prompts = vec![chat::Prompt { let mut prompts = vec![chat::Prompt {
role: Cow::Borrowed("system"), role: "system".to_string(),
content: Cow::Borrowed(&system_prompt), content: self.system_prompt.clone(),
}]; }];
prompts.extend( prompts.extend(
app.messages self.messages
.iter() .iter()
.map(|msg| chat::Prompt::from(msg.clone())), .map(|msg| chat::Prompt::from(msg.clone())),
); );
let req = chat::ChatRequest { let req = chat::ChatRequest {
model: &app.args.model.clone(), model: self.args.model.clone(),
stream: app.args.stream, stream: self.args.stream,
format: "json", format: "json".to_string(),
stop: vec!["\n\n\n\n"], stop: vec!["\n\n\n\n".to_string()],
options: Some(chat::ChatOptions { options: Some(chat::ChatOptions {
temperature: Some(0.3), temperature: Some(0.3),
top_p: Some(0.92), top_p: Some(0.92),
@@ -142,167 +160,130 @@ async fn main() -> anyhow::Result<()> {
messages: prompts, messages: prompts,
}; };
app.waiting = true; self.waiting = true;
match app.args.stream { return Ok(Some(Cmd::RunChat { req }));
true => {
todo!();
stream_ollama_response(&mut app, client.clone(), req).await?;
} }
false => { _ => { /* ignore all other keys */ }
batch_ollama_response(&mut app, client.clone(), req).await?;
} }
} }
Event::Mouse(mouse_event) => match mouse_event.kind {
event::MouseEventKind::Up(MouseButton::Left) => {}
_ => {}
},
Event::Paste(_) => { /* do nothing */ }
Event::Resize(_, _) => { /* do nothing */ }
} }
KeyCode::Esc => {
Ok(None)
}
}
/// Cmds that can arrive in the command event queue
enum Cmd {
RunChat { req: chat::ChatRequest },
GetAddr,
Quit,
}
/// Messages that can arrive in the UI loop
enum Msg {
Input(CEvent),
HttpDone(Result<String, reqwest::Error>),
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// parse arguments from Clap
let args = match Args::try_parse() {
Ok(args) => args,
Err(e) => {
e.print().expect("Error writing clap error");
std::process::exit(0);
}
};
// UI Event Loop
let mut events = EventStream::new();
let mut ticker = tokio::time::interval(std::time::Duration::from_millis(33));
let mut terminal = OxiTerminal::setup();
let mut state = AppState::default(args);
'uiloop: loop {
// first non-blocking drain of all pending messages
while let Ok(msg) = state.queues.rx_msg.try_recv() {
match msg {
Msg::Input(ev) => match ev.as_key_event() {
Some(ke) => {
if ke.code == KeyCode::Esc {
return terminal.term_cleanup();
} else {
if let Some(cmd) = state.handle_input(ev)? {
if state.queues.tx_cmd.send(cmd).is_err() {
break; break;
} }
_ => {} }
}
}
None => {}
},
Msg::HttpDone(r) => state.handle_http_done(r)?,
};
}
// block until either next tick or next user input
tokio::select! {
_ = ticker.tick() => { terminal.do_draw(&state); },
maybe_ev = events.next() => {
if let Some(Ok(ev)) = maybe_ev {
if state.queues.tx_msg.send(Msg::Input(ev)).is_err() { break 'uiloop }
} }
} }
} }
} }
disable_raw_mode()?;
crossterm::execute!(
terminal.backend_mut(),
LeaveAlternateScreen,
DisableMouseCapture
)?;
terminal.show_cursor()?;
Ok(()) Ok(())
} }
//FIXME: streaming replies are harder to work with for now, save this for the future async fn run_workers(
async fn stream_ollama_response( mut rx_cmd: mpsc::UnboundedReceiver<Cmd>,
app: &mut App, tx_msg: mpsc::UnboundedSender<Msg>,
client: Client, model: String,
req: chat::ChatRequest<'_>, ) {
) -> anyhow::Result<()> { while let Some(cmd) = rx_cmd.recv().await {
let mut resp = client match cmd {
Cmd::RunChat { req } => {
let tx_msg = tx_msg.clone();
tokio::spawn(async move {
let res = ollama_call(req).await; // see next section
let _ = tx_msg.send(Msg::HttpDone(res));
});
}
Cmd::GetAddr => {
// --- Kick off an HTTP worker as a proof-of-concept ----
let tx_msg = tx_msg.clone();
tokio::spawn(async move {
let res: Result<String, reqwest::Error> = async {
let resp = reqwest::get("https://ifconfig.me/all").await?;
resp.text().await
}
.await;
let _ = tx_msg.send(Msg::HttpDone(res));
});
}
Cmd::Quit => break,
}
}
}
async fn ollama_call(req: chat::ChatRequest) -> Result<String, reqwest::Error> {
let client = reqwest::Client::new();
client
.post("http://localhost:11434/api/chat") .post("http://localhost:11434/api/chat")
.json(&req) .json(&req)
.send() .send()
.await? .await?
.bytes_stream(); .text()
.await
//TODO: since we haven't decoded the stream we don't know if it's sent the role part of the message
// we'll need to figure out how to 'see the future' so to speak
let mut assistant_line = String::from("Assistant : ");
while let Some(chunk) = resp.next().await {
let chunk = chunk?;
for line in chunk.split(|b| *b == b'\n') {
if line.is_empty() {
continue;
}
let parsed: serde_json::Result<chat::StreamChunk> = serde_json::from_slice(line);
if let Ok(parsed) = parsed {
assistant_line.push_str(&parsed.message.content.to_string());
}
}
}
//FIXME: fix this later
//app.messages.push(assistant_line);
Ok(())
}
async fn batch_ollama_response<'a>(
app: &mut App,
client: Client,
req: chat::ChatRequest<'a>,
) -> anyhow::Result<()> {
batch_ollama_response_inner(app, client, req).await
}
fn batch_ollama_response_inner<'a>(
app: &'a mut App,
client: Client,
req: chat::ChatRequest<'a>,
) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>> {
Box::pin(async move {
let start = Instant::now();
let resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
.send()
.await?;
let elapsed = start.elapsed();
let status = resp.status();
let headers = resp.headers().clone();
let body_bytes = resp.bytes().await?;
match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
Ok(r) => {
match r.message.content.action {
chat::Action::Chat => app.messages.push(r.message),
chat::Action::Tool(assistant_tool) => {
match assistant_tool {
chat::AssistantTool::WikiSearch => {
//HACK: fake it for now, until I figure out how to grab a web page and display it in a way the model understands
let tool_args = r.message.content.arguments.clone();
app.messages.push(r.message);
let search_term = match tool_args.get("query") {
Some(v) => v.as_str(),
None => todo!(),
};
let tool_response = match search_term {
"American Crow" => {
let r = args_builder! {
"result" => include_str!("data/american_crow_wikipedia.md")
};
r
}
"Black Bear" => {
let r = args_builder! {
"result" => include_str!("data/black_bear_wikipedia.md")
};
r
}
_ => {
let r = args_builder! {
"result" => "Search failed to return any valid data"
};
r
}
};
let tool_message = Message::from((
chat::MessageRoles::Tool,
Action::Tool(chat::AssistantTool::WikiSearch),
tool_response,
));
app.messages.push(tool_message);
//FIXME: model could recurse forever
batch_ollama_response(app, client.clone(), req).await?;
}
chat::AssistantTool::WebSearch => todo!(),
chat::AssistantTool::GetDateTime => todo!(),
chat::AssistantTool::GetDirectoryTree => todo!(),
chat::AssistantTool::GetFileContents => todo!(),
chat::AssistantTool::InvalidTool => todo!(),
}
}
}
}
Err(e) => {
println!("Failed to parse JSON: {}", e);
println!("Status: {}", status);
println!("Headers: {:#?}", headers);
// Try to print the body as text for debugging
if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
println!("Body text: {}", body_text);
} else {
println!("Body was not valid UTF-8");
}
}
}
app.waiting = false;
Ok(())
})
} }

View File

@@ -1,3 +1,17 @@
use std::io::IsTerminal;
use crossterm::terminal::{
EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode,
};
use crossterm::event::{
self, DisableMouseCapture, EnableMouseCapture, Event, Event as CEvent, EventStream, KeyCode,
MouseButton,
};
use ratatui::CompletedFrame;
use ratatui::prelude::Backend;
use ratatui::{Frame, Terminal, backend::CrosstermBackend};
use ratatui::{ use ratatui::{
layout::{Constraint, Direction, Layout}, layout::{Constraint, Direction, Layout},
style::{Color, Style}, style::{Color, Style},
@@ -5,7 +19,44 @@ use ratatui::{
widgets::{Block, Borders, Paragraph}, widgets::{Block, Borders, Paragraph},
}; };
pub fn chat_ui(f: &mut ratatui::Frame, app: &crate::App) { use crate::AppState;
pub struct OxiTerminal {
handle: Terminal<CrosstermBackend<std::io::Stdout>>,
}
impl OxiTerminal {
pub fn setup() -> Self {
enable_raw_mode(); // crossterm
let mut stdout_handle = std::io::stdout();
crossterm::execute!(stdout_handle, EnterAlternateScreen, EnableMouseCapture);
let backend = CrosstermBackend::new(stdout_handle);
let mut handle = Terminal::new(backend).expect("unable to open a terminal");
handle.clear();
OxiTerminal { handle }
}
pub fn do_draw(&mut self, app: &AppState) -> CompletedFrame {
self.handle
.draw(|f| OxiTerminal::chat_ui(f, app))
.expect("failed to draw to framebuffer")
}
pub fn term_cleanup(&mut self) -> anyhow::Result<()> {
disable_raw_mode()?;
crossterm::execute!(
self.handle.backend_mut(),
LeaveAlternateScreen,
DisableMouseCapture
)?;
self.handle.show_cursor()?;
Ok(())
}
//FIXME: awaiting refactor
pub fn chat_ui(f: &mut ratatui::Frame, app: &crate::AppState) {
let chunks = Layout::default() let chunks = Layout::default()
.direction(Direction::Vertical) .direction(Direction::Vertical)
.margin(1) .margin(1)
@@ -54,4 +105,5 @@ pub fn chat_ui(f: &mut ratatui::Frame, app: &crate::App) {
chunks[1].x + app.prompt.len() as u16 + 3, chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1, chunks[1].y + 1,
)); ));
}
} }