diff --git a/src/chat/mod.rs b/src/chat/mod.rs index 679ae85..0c1597b 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; -use serde::de::{self, Deserializer}; +use serde::de::{self, Deserializer as DeDeserializer, IntoDeserializer, Visitor}; +use serde::{Deserialize, Serialize, Serializer, Deserializer}; use std::borrow::Cow; use std::collections::HashMap; @@ -65,7 +65,7 @@ impl Display for MessageRoles { MessageRoles::Tool => "tool", MessageRoles::User => "user", MessageRoles::Assistant => "assistant", - //HACK: Handle this cleanly, if the model hallucinates we crash :^) + //HACK: Handle this cleanly, if the model hallucinates a role we crash :^) MessageRoles::Other => todo!(), }; @@ -91,7 +91,7 @@ impl Message { // Custom deserializer function fn de_content<'de, D>(deserializer: D) -> Result where - D: Deserializer<'de>, + D: DeDeserializer<'de>, { let s = String::deserialize(deserializer)?; serde_json::from_str(&s).map_err(de::Error::custom) @@ -110,7 +110,9 @@ impl Display for Message { } } -#[derive(Serialize, Deserialize, PartialEq)] +#[derive(Serialize, Deserialize, PartialEq, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] pub enum AssistantTool { WikiSearch, WebSearch, @@ -125,26 +127,26 @@ impl Display for AssistantTool { let res = match self { AssistantTool::WikiSearch => "wiki_search", AssistantTool::WebSearch => "web_search", - AssistantTool::GetDateTime => "get_datetime", - AssistantTool::GetDirectoryTree => "get_dirtree", - AssistantTool::GetFileContents => "get_file", - //HACK: Handle this cleanly, if the model hallucinates we crash :^) - AssistantTool::InvalidTool => todo!(), + AssistantTool::GetDateTime => "get_date_time", + AssistantTool::GetDirectoryTree => "get_dir_tree", + AssistantTool::GetFileContents => "get_file_contents", + AssistantTool::InvalidTool => "invalid_tool", }; write!(f, "{}", res) } } -#[derive(Serialize, Deserialize, PartialEq)] +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] +#[serde(rename_all = "lowercase")] pub enum Action { - ChatMessage, + Chat, Tool(AssistantTool), } impl Display for Action { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { - Action::ChatMessage => write!(f, "{}", "chat"), + Action::Chat => write!(f, "{}", "chat"), Action::Tool(tool_name) => write!(f, "{tool_name}"), } } @@ -152,16 +154,13 @@ impl Display for Action { #[derive(Serialize, Deserialize, Clone, PartialEq)] pub struct ActionPacket { - pub action: String, + pub action: Action, pub arguments: HashMap, } impl ActionPacket { pub fn new(action: Action, arguments: HashMap) -> Self { - Self { - action: action.to_string(), - arguments, - } + Self { action, arguments } } } @@ -197,4 +196,4 @@ macro_rules! args_builder { )* map }}; -} \ No newline at end of file +} diff --git a/src/tool/american_crow_example1_message.md b/src/data/american_crow_example1_message.md similarity index 100% rename from src/tool/american_crow_example1_message.md rename to src/data/american_crow_example1_message.md diff --git a/src/tool/american_crow_wikipedia.md b/src/data/american_crow_wikipedia.md similarity index 100% rename from src/tool/american_crow_wikipedia.md rename to src/data/american_crow_wikipedia.md diff --git a/src/tool/black_bear_wikipedia.md b/src/data/black_bear_wikipedia.md similarity index 100% rename from src/tool/black_bear_wikipedia.md rename to src/data/black_bear_wikipedia.md diff --git a/src/tool/tools_list.json b/src/data/tools_list.json similarity index 100% rename from src/tool/tools_list.json rename to src/data/tools_list.json diff --git a/src/main.rs b/src/main.rs index 122c507..fe362d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ use std::borrow::Cow; +use std::pin::Pin; use std::time::{Duration, Instant}; -use chat::Message; +use chat::{Action, Message}; use clap::Parser; use futures_util::StreamExt; use reqwest::Client; @@ -11,17 +12,10 @@ use crossterm::terminal::{ EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, }; -use ratatui::{ - Terminal, - backend::CrosstermBackend, - layout::{Constraint, Direction, Layout}, - style::{Color, Style}, - text::{Line, Span}, - widgets::{Block, Borders, Paragraph}, -}; -use serde::{Deserialize, Serialize}; +use ratatui::{Terminal, backend::CrosstermBackend}; mod chat; +mod ui; #[derive(Parser)] struct Args { @@ -36,11 +30,11 @@ struct Args { #[arg( short, long, - help = "Should the response be streamed from ollama or sent all at once" + help = "(Broken) Should the response be streamed from ollama or sent all at once" )] stream: bool, - #[arg(short, long, help = "Show statistics in non-stream mode?")] + #[arg(short, long, help = "(Broken) Show statistics in non-stream mode?")] nerd_stats: bool, } @@ -64,9 +58,9 @@ async fn main() -> anyhow::Result<()> { // setup crossterm enable_raw_mode()?; - let mut stdout = std::io::stdout(); - crossterm::execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; - let backend = CrosstermBackend::new(stdout); + let mut stdout_handle = std::io::stdout(); + crossterm::execute!(stdout_handle, EnterAlternateScreen, EnableMouseCapture)?; + let backend = CrosstermBackend::new(stdout_handle); let mut terminal = Terminal::new(backend)?; let mut app = App { @@ -78,10 +72,9 @@ async fn main() -> anyhow::Result<()> { let client = Client::new(); - let header_prompt = - r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#; + let header_prompt = r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#; - let tools_list = include_str!("tool/tools_list.json") + let tools_list = include_str!("data/tools_list.json") .parse::()? .to_string(); @@ -94,23 +87,6 @@ async fn main() -> anyhow::Result<()> { 6. Base claims strictly on provided data or tool results. If unsure, say so. 7. Check your output; If you reach four consecutive newlines: *stop*"#; - let example_prompt = format!( - "Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}", - user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#, - assistant_tool_request_1 = format!( - "assistant: {{ \"action\":\"wiki_search\",\"arguments\":{{\"query\":\"American Crow\"}} }}" - ), - tool_result_1 = format!( - "tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}", - search_data = include_str!("tool/american_crow_wikipedia.md").to_string() - ), - assistant_a_1 = format!( - "assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}", - example1_assistant_message = - include_str!("tool/american_crow_example1_message.md").to_string() - ) - ); - //let user_info_prompt = r#""#; let system_prompt = format!( "{header_prompt}\n @@ -119,7 +95,7 @@ async fn main() -> anyhow::Result<()> { ); loop { - terminal.draw(|f| chat_ui(f, &app))?; + terminal.draw(|f| ui::chat_ui(f, &app))?; if event::poll(Duration::from_millis(100))? { if let Event::Key(key) = event::read()? { @@ -137,8 +113,9 @@ async fn main() -> anyhow::Result<()> { app.messages.push(chat::Message::new( chat::MessageRoles::User, - chat::Action::ChatMessage, - message_args)); + chat::Action::Chat, + message_args, + )); let mut prompts = vec![chat::Prompt { role: Cow::Borrowed("system"), @@ -236,121 +213,96 @@ async fn batch_ollama_response<'a>( client: Client, req: chat::ChatRequest<'a>, ) -> anyhow::Result<()> { - let start = Instant::now(); - let resp = client - .post("http://localhost:11434/api/chat") - .json(&req) - .send() - .await?; - let elapsed = start.elapsed(); + batch_ollama_response_inner(app, client, req).await +} - let status = resp.status(); - let headers = resp.headers().clone(); - let body_bytes = resp.bytes().await?; - - match serde_json::from_slice::(&body_bytes) { - Ok(r) => app.messages.push(r.message), - Err(e) => { - println!("Failed to parse JSON: {}", e); - println!("Status: {}", status); - println!("Headers: {:#?}", headers); - // Try to print the body as text for debugging - if let Ok(body_text) = std::str::from_utf8(&body_bytes) { - println!("Body text: {}", body_text); - } else { - println!("Body was not valid UTF-8"); +fn batch_ollama_response_inner<'a>( + app: &'a mut App, + client: Client, + req: chat::ChatRequest<'a>, +) -> Pin> + Send + 'a>> { + Box::pin(async move { + let start = Instant::now(); + let resp = client + .post("http://localhost:11434/api/chat") + .json(&req) + .send() + .await?; + let elapsed = start.elapsed(); + + let status = resp.status(); + let headers = resp.headers().clone(); + let body_bytes = resp.bytes().await?; + + match serde_json::from_slice::(&body_bytes) { + Ok(r) => { + match r.message.content.action { + chat::Action::Chat => app.messages.push(r.message), + chat::Action::Tool(assistant_tool) => { + match assistant_tool { + chat::AssistantTool::WikiSearch => { + //HACK: fake it for now, until I figure out how to grab a web page and display it in a way the model understands + let tool_args = r.message.content.arguments.clone(); + app.messages.push(r.message); + + let search_term = match tool_args.get("query") { + Some(v) => v.as_str(), + None => todo!(), + }; + + let tool_response = match search_term { + "American Crow" => { + let r = args_builder! { + "result" => include_str!("data/american_crow_wikipedia.md") + }; + r + } + "Black Bear" => { + let r = args_builder! { + "result" => include_str!("data/black_bear_wikipedia.md") + }; + r + } + _ => { + let r = args_builder! { + "result" => "Search failed to return any valid data" + }; + r + } + }; + + let tool_message = Message::from(( + chat::MessageRoles::Tool, + Action::Tool(chat::AssistantTool::WikiSearch), + tool_response, + )); + app.messages.push(tool_message); + //FIXME: model could recurse forever + batch_ollama_response(app, client.clone(), req).await?; + } + chat::AssistantTool::WebSearch => todo!(), + chat::AssistantTool::GetDateTime => todo!(), + chat::AssistantTool::GetDirectoryTree => todo!(), + chat::AssistantTool::GetFileContents => todo!(), + chat::AssistantTool::InvalidTool => todo!(), + } + } + } + } + Err(e) => { + println!("Failed to parse JSON: {}", e); + println!("Status: {}", status); + println!("Headers: {:#?}", headers); + // Try to print the body as text for debugging + if let Ok(body_text) = std::str::from_utf8(&body_bytes) { + println!("Body text: {}", body_text); + } else { + println!("Body was not valid UTF-8"); + } } } - } - /* - if app.args.nerd_stats { - app.messages.push(format!( - "System : Response generated via {} model with timestamp {}", - resp.model, resp.created_at - )); - - app.messages.push(format!( - "System : done_reason = {}, done = {}", - resp.done_reason, resp.done - )); - - app.messages - .push(format!("System : Response timing statistics...")); - - app.messages - .push(format!("System : Total elapsed wall time: {:.2?}", elapsed)); - app.messages.push(format!( - "System : Prompt tokens: {}", - resp.prompt_eval_count - )); - app.messages.push(format!( - "System : Prompt eval duration: {} ns", - resp.prompt_eval_duration - )); - app.messages - .push(format!("System : Output tokens: {}", resp.eval_count)); - app.messages.push(format!( - "System : Output eval duration: {} ns", - resp.eval_duration - )); - app.messages.push(format!( - "System : Model 'warm up' time {}", - (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration)) - )); - - let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0); - app.messages.push(format!( - "System > Output generation speed: {:.2} tokens/sec", - token_speed - )); - } - */ - app.waiting = false; - Ok(()) -} - -fn chat_ui(f: &mut ratatui::Frame, app: &App) { - let chunks = Layout::default() - .direction(Direction::Vertical) - .margin(1) - .constraints([Constraint::Min(1), Constraint::Length(3)].as_ref()) - .split(f.area()); - - let messages: Vec = app - .messages - .iter() - .map(|m| Line::from(Span::raw(m.to_string()))) - .collect(); - - let messages_block = Paragraph::new(ratatui::text::Text::from(messages)) - .block(Block::default().borders(Borders::ALL).title("Chat")) - .wrap(ratatui::widgets::Wrap { trim: true }) - .scroll(( - app.messages - .len() - .saturating_sub((chunks[0].height - 2) as usize) as u16, - 0, - )); - - f.render_widget(messages_block, chunks[0]); - - let input_text = if app.waiting { - format!("> {} (waiting...)", &app.prompt) - } else { - format!("> {}", app.prompt) - }; - - let input = Paragraph::new(input_text) - .style(Style::default().fg(Color::Yellow)) - .block(Block::default().borders(Borders::ALL).title("Input")); - f.render_widget(input, chunks[1]); - - use ratatui::layout::Position; - f.set_cursor_position(Position::new( - // the +3 comes from the 3 'characters' of space between the terminal edge and the text location - // this places the text cursor after the last entered character - chunks[1].x + app.prompt.len() as u16 + 3, - chunks[1].y + 1, - )); + app.waiting = false; + Ok(()) + }) } diff --git a/src/tool/mod.rs b/src/tool/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/ui/mod.rs b/src/ui/mod.rs new file mode 100644 index 0000000..dfa271e --- /dev/null +++ b/src/ui/mod.rs @@ -0,0 +1,57 @@ +use ratatui::{ + layout::{Constraint, Direction, Layout}, + style::{Color, Style}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph}, +}; + +pub fn chat_ui(f: &mut ratatui::Frame, app: &crate::App) { + let chunks = Layout::default() + .direction(Direction::Vertical) + .margin(1) + .constraints([Constraint::Min(1), Constraint::Length(3)].as_ref()) + .split(f.area()); + + let chat_messages: Vec = app + .messages + .iter() + .map(|m| { + Line::from(Span::raw(format!( + "{}: {}", + m.role.to_string(), + m.to_string() + ))) + }) + .collect(); + + let messages_block = Paragraph::new(ratatui::text::Text::from(chat_messages)) + .block(Block::default().borders(Borders::ALL).title("Chat")) + .wrap(ratatui::widgets::Wrap { trim: true }) + .scroll(( + app.messages + .len() + .saturating_sub((chunks[0].height - 2) as usize) as u16, + 0, + )); + + f.render_widget(messages_block, chunks[0]); + + let input_text = if app.waiting { + format!("> {} (waiting...)", &app.prompt) + } else { + format!("> {}", app.prompt) + }; + + let input = Paragraph::new(input_text) + .style(Style::default().fg(Color::Yellow)) + .block(Block::default().borders(Borders::ALL).title("Input")); + f.render_widget(input, chunks[1]); + + use ratatui::layout::Position; + f.set_cursor_position(Position::new( + // the +3 comes from the 3 'characters' of space between the terminal edge and the text location + // this places the text cursor after the last entered character + chunks[1].x + app.prompt.len() as u16 + 3, + chunks[1].y + 1, + )); +}