bit of refactoring, first attempts at a conversational style.

going to need to really seperate the ui logic since the code is getting too big to just put it all together.
This commit is contained in:
2025-04-29 13:38:15 -04:00
parent bae93b9b3f
commit a8f264bbbc
8 changed files with 178 additions and 170 deletions

View File

@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize}; use serde::de::{self, Deserializer as DeDeserializer, IntoDeserializer, Visitor};
use serde::de::{self, Deserializer}; use serde::{Deserialize, Serialize, Serializer, Deserializer};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
@@ -65,7 +65,7 @@ impl Display for MessageRoles {
MessageRoles::Tool => "tool", MessageRoles::Tool => "tool",
MessageRoles::User => "user", MessageRoles::User => "user",
MessageRoles::Assistant => "assistant", MessageRoles::Assistant => "assistant",
//HACK: Handle this cleanly, if the model hallucinates we crash :^) //HACK: Handle this cleanly, if the model hallucinates a role we crash :^)
MessageRoles::Other => todo!(), MessageRoles::Other => todo!(),
}; };
@@ -91,7 +91,7 @@ impl Message {
// Custom deserializer function // Custom deserializer function
fn de_content<'de, D>(deserializer: D) -> Result<ActionPacket, D::Error> fn de_content<'de, D>(deserializer: D) -> Result<ActionPacket, D::Error>
where where
D: Deserializer<'de>, D: DeDeserializer<'de>,
{ {
let s = String::deserialize(deserializer)?; let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(de::Error::custom) serde_json::from_str(&s).map_err(de::Error::custom)
@@ -110,7 +110,9 @@ impl Display for Message {
} }
} }
#[derive(Serialize, Deserialize, PartialEq)] #[derive(Serialize, Deserialize, PartialEq, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum AssistantTool { pub enum AssistantTool {
WikiSearch, WikiSearch,
WebSearch, WebSearch,
@@ -125,26 +127,26 @@ impl Display for AssistantTool {
let res = match self { let res = match self {
AssistantTool::WikiSearch => "wiki_search", AssistantTool::WikiSearch => "wiki_search",
AssistantTool::WebSearch => "web_search", AssistantTool::WebSearch => "web_search",
AssistantTool::GetDateTime => "get_datetime", AssistantTool::GetDateTime => "get_date_time",
AssistantTool::GetDirectoryTree => "get_dirtree", AssistantTool::GetDirectoryTree => "get_dir_tree",
AssistantTool::GetFileContents => "get_file", AssistantTool::GetFileContents => "get_file_contents",
//HACK: Handle this cleanly, if the model hallucinates we crash :^) AssistantTool::InvalidTool => "invalid_tool",
AssistantTool::InvalidTool => todo!(),
}; };
write!(f, "{}", res) write!(f, "{}", res)
} }
} }
#[derive(Serialize, Deserialize, PartialEq)] #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
#[serde(rename_all = "lowercase")]
pub enum Action { pub enum Action {
ChatMessage, Chat,
Tool(AssistantTool), Tool(AssistantTool),
} }
impl Display for Action { impl Display for Action {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self { match self {
Action::ChatMessage => write!(f, "{}", "chat"), Action::Chat => write!(f, "{}", "chat"),
Action::Tool(tool_name) => write!(f, "{tool_name}"), Action::Tool(tool_name) => write!(f, "{tool_name}"),
} }
} }
@@ -152,16 +154,13 @@ impl Display for Action {
#[derive(Serialize, Deserialize, Clone, PartialEq)] #[derive(Serialize, Deserialize, Clone, PartialEq)]
pub struct ActionPacket { pub struct ActionPacket {
pub action: String, pub action: Action,
pub arguments: HashMap<String, String>, pub arguments: HashMap<String, String>,
} }
impl ActionPacket { impl ActionPacket {
pub fn new(action: Action, arguments: HashMap<String, String>) -> Self { pub fn new(action: Action, arguments: HashMap<String, String>) -> Self {
Self { Self { action, arguments }
action: action.to_string(),
arguments,
}
} }
} }

View File

@@ -1,7 +1,8 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::pin::Pin;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use chat::Message; use chat::{Action, Message};
use clap::Parser; use clap::Parser;
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
@@ -11,17 +12,10 @@ use crossterm::terminal::{
EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode,
}; };
use ratatui::{ use ratatui::{Terminal, backend::CrosstermBackend};
Terminal,
backend::CrosstermBackend,
layout::{Constraint, Direction, Layout},
style::{Color, Style},
text::{Line, Span},
widgets::{Block, Borders, Paragraph},
};
use serde::{Deserialize, Serialize};
mod chat; mod chat;
mod ui;
#[derive(Parser)] #[derive(Parser)]
struct Args { struct Args {
@@ -36,11 +30,11 @@ struct Args {
#[arg( #[arg(
short, short,
long, long,
help = "Should the response be streamed from ollama or sent all at once" help = "(Broken) Should the response be streamed from ollama or sent all at once"
)] )]
stream: bool, stream: bool,
#[arg(short, long, help = "Show statistics in non-stream mode?")] #[arg(short, long, help = "(Broken) Show statistics in non-stream mode?")]
nerd_stats: bool, nerd_stats: bool,
} }
@@ -64,9 +58,9 @@ async fn main() -> anyhow::Result<()> {
// setup crossterm // setup crossterm
enable_raw_mode()?; enable_raw_mode()?;
let mut stdout = std::io::stdout(); let mut stdout_handle = std::io::stdout();
crossterm::execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; crossterm::execute!(stdout_handle, EnterAlternateScreen, EnableMouseCapture)?;
let backend = CrosstermBackend::new(stdout); let backend = CrosstermBackend::new(stdout_handle);
let mut terminal = Terminal::new(backend)?; let mut terminal = Terminal::new(backend)?;
let mut app = App { let mut app = App {
@@ -78,10 +72,9 @@ async fn main() -> anyhow::Result<()> {
let client = Client::new(); let client = Client::new();
let header_prompt = let header_prompt = r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#;
r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#;
let tools_list = include_str!("tool/tools_list.json") let tools_list = include_str!("data/tools_list.json")
.parse::<serde_json::Value>()? .parse::<serde_json::Value>()?
.to_string(); .to_string();
@@ -94,23 +87,6 @@ async fn main() -> anyhow::Result<()> {
6. Base claims strictly on provided data or tool results. If unsure, say so. 6. Base claims strictly on provided data or tool results. If unsure, say so.
7. Check your output; If you reach four consecutive newlines: *stop*"#; 7. Check your output; If you reach four consecutive newlines: *stop*"#;
let example_prompt = format!(
"Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}",
user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#,
assistant_tool_request_1 = format!(
"assistant: {{ \"action\":\"wiki_search\",\"arguments\":{{\"query\":\"American Crow\"}} }}"
),
tool_result_1 = format!(
"tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}",
search_data = include_str!("tool/american_crow_wikipedia.md").to_string()
),
assistant_a_1 = format!(
"assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}",
example1_assistant_message =
include_str!("tool/american_crow_example1_message.md").to_string()
)
);
//let user_info_prompt = r#""#; //let user_info_prompt = r#""#;
let system_prompt = format!( let system_prompt = format!(
"{header_prompt}\n "{header_prompt}\n
@@ -119,7 +95,7 @@ async fn main() -> anyhow::Result<()> {
); );
loop { loop {
terminal.draw(|f| chat_ui(f, &app))?; terminal.draw(|f| ui::chat_ui(f, &app))?;
if event::poll(Duration::from_millis(100))? { if event::poll(Duration::from_millis(100))? {
if let Event::Key(key) = event::read()? { if let Event::Key(key) = event::read()? {
@@ -137,8 +113,9 @@ async fn main() -> anyhow::Result<()> {
app.messages.push(chat::Message::new( app.messages.push(chat::Message::new(
chat::MessageRoles::User, chat::MessageRoles::User,
chat::Action::ChatMessage, chat::Action::Chat,
message_args)); message_args,
));
let mut prompts = vec![chat::Prompt { let mut prompts = vec![chat::Prompt {
role: Cow::Borrowed("system"), role: Cow::Borrowed("system"),
@@ -236,121 +213,96 @@ async fn batch_ollama_response<'a>(
client: Client, client: Client,
req: chat::ChatRequest<'a>, req: chat::ChatRequest<'a>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let start = Instant::now(); batch_ollama_response_inner(app, client, req).await
let resp = client }
.post("http://localhost:11434/api/chat")
.json(&req)
.send()
.await?;
let elapsed = start.elapsed();
let status = resp.status(); fn batch_ollama_response_inner<'a>(
let headers = resp.headers().clone(); app: &'a mut App,
let body_bytes = resp.bytes().await?; client: Client,
req: chat::ChatRequest<'a>,
) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>> {
Box::pin(async move {
let start = Instant::now();
let resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
.send()
.await?;
let elapsed = start.elapsed();
match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) { let status = resp.status();
Ok(r) => app.messages.push(r.message), let headers = resp.headers().clone();
Err(e) => { let body_bytes = resp.bytes().await?;
println!("Failed to parse JSON: {}", e);
println!("Status: {}", status); match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
println!("Headers: {:#?}", headers); Ok(r) => {
// Try to print the body as text for debugging match r.message.content.action {
if let Ok(body_text) = std::str::from_utf8(&body_bytes) { chat::Action::Chat => app.messages.push(r.message),
println!("Body text: {}", body_text); chat::Action::Tool(assistant_tool) => {
} else { match assistant_tool {
println!("Body was not valid UTF-8"); chat::AssistantTool::WikiSearch => {
//HACK: fake it for now, until I figure out how to grab a web page and display it in a way the model understands
let tool_args = r.message.content.arguments.clone();
app.messages.push(r.message);
let search_term = match tool_args.get("query") {
Some(v) => v.as_str(),
None => todo!(),
};
let tool_response = match search_term {
"American Crow" => {
let r = args_builder! {
"result" => include_str!("data/american_crow_wikipedia.md")
};
r
}
"Black Bear" => {
let r = args_builder! {
"result" => include_str!("data/black_bear_wikipedia.md")
};
r
}
_ => {
let r = args_builder! {
"result" => "Search failed to return any valid data"
};
r
}
};
let tool_message = Message::from((
chat::MessageRoles::Tool,
Action::Tool(chat::AssistantTool::WikiSearch),
tool_response,
));
app.messages.push(tool_message);
//FIXME: model could recurse forever
batch_ollama_response(app, client.clone(), req).await?;
}
chat::AssistantTool::WebSearch => todo!(),
chat::AssistantTool::GetDateTime => todo!(),
chat::AssistantTool::GetDirectoryTree => todo!(),
chat::AssistantTool::GetFileContents => todo!(),
chat::AssistantTool::InvalidTool => todo!(),
}
}
}
}
Err(e) => {
println!("Failed to parse JSON: {}", e);
println!("Status: {}", status);
println!("Headers: {:#?}", headers);
// Try to print the body as text for debugging
if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
println!("Body text: {}", body_text);
} else {
println!("Body was not valid UTF-8");
}
} }
} }
}
/* app.waiting = false;
if app.args.nerd_stats { Ok(())
app.messages.push(format!( })
"System : Response generated via {} model with timestamp {}",
resp.model, resp.created_at
));
app.messages.push(format!(
"System : done_reason = {}, done = {}",
resp.done_reason, resp.done
));
app.messages
.push(format!("System : Response timing statistics..."));
app.messages
.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
app.messages.push(format!(
"System : Prompt tokens: {}",
resp.prompt_eval_count
));
app.messages.push(format!(
"System : Prompt eval duration: {} ns",
resp.prompt_eval_duration
));
app.messages
.push(format!("System : Output tokens: {}", resp.eval_count));
app.messages.push(format!(
"System : Output eval duration: {} ns",
resp.eval_duration
));
app.messages.push(format!(
"System : Model 'warm up' time {}",
(resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))
));
let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
app.messages.push(format!(
"System > Output generation speed: {:.2} tokens/sec",
token_speed
));
}
*/
app.waiting = false;
Ok(())
}
fn chat_ui(f: &mut ratatui::Frame, app: &App) {
let chunks = Layout::default()
.direction(Direction::Vertical)
.margin(1)
.constraints([Constraint::Min(1), Constraint::Length(3)].as_ref())
.split(f.area());
let messages: Vec<Line> = app
.messages
.iter()
.map(|m| Line::from(Span::raw(m.to_string())))
.collect();
let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
.block(Block::default().borders(Borders::ALL).title("Chat"))
.wrap(ratatui::widgets::Wrap { trim: true })
.scroll((
app.messages
.len()
.saturating_sub((chunks[0].height - 2) as usize) as u16,
0,
));
f.render_widget(messages_block, chunks[0]);
let input_text = if app.waiting {
format!("> {} (waiting...)", &app.prompt)
} else {
format!("> {}", app.prompt)
};
let input = Paragraph::new(input_text)
.style(Style::default().fg(Color::Yellow))
.block(Block::default().borders(Borders::ALL).title("Input"));
f.render_widget(input, chunks[1]);
use ratatui::layout::Position;
f.set_cursor_position(Position::new(
// the +3 comes from the 3 'characters' of space between the terminal edge and the text location
// this places the text cursor after the last entered character
chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1,
));
} }

0
src/tool/mod.rs Normal file
View File

57
src/ui/mod.rs Normal file
View File

@@ -0,0 +1,57 @@
use ratatui::{
layout::{Constraint, Direction, Layout},
style::{Color, Style},
text::{Line, Span},
widgets::{Block, Borders, Paragraph},
};
pub fn chat_ui(f: &mut ratatui::Frame, app: &crate::App) {
let chunks = Layout::default()
.direction(Direction::Vertical)
.margin(1)
.constraints([Constraint::Min(1), Constraint::Length(3)].as_ref())
.split(f.area());
let chat_messages: Vec<Line> = app
.messages
.iter()
.map(|m| {
Line::from(Span::raw(format!(
"{}: {}",
m.role.to_string(),
m.to_string()
)))
})
.collect();
let messages_block = Paragraph::new(ratatui::text::Text::from(chat_messages))
.block(Block::default().borders(Borders::ALL).title("Chat"))
.wrap(ratatui::widgets::Wrap { trim: true })
.scroll((
app.messages
.len()
.saturating_sub((chunks[0].height - 2) as usize) as u16,
0,
));
f.render_widget(messages_block, chunks[0]);
let input_text = if app.waiting {
format!("> {} (waiting...)", &app.prompt)
} else {
format!("> {}", app.prompt)
};
let input = Paragraph::new(input_text)
.style(Style::default().fg(Color::Yellow))
.block(Block::default().borders(Borders::ALL).title("Input"));
f.render_widget(input, chunks[1]);
use ratatui::layout::Position;
f.set_cursor_position(Position::new(
// the +3 comes from the 3 'characters' of space between the terminal edge and the text location
// this places the text cursor after the last entered character
chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1,
));
}