A bit of refactoring, and first attempts at a conversational style.

Going to need to really separate the UI logic, since the code is getting too big to keep it all in one file.
2025-04-29 13:38:15 -04:00
parent bae93b9b3f
commit a8f264bbbc
8 changed files with 178 additions and 170 deletions
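
The diff below only touches src/main.rs, so the extracted `ui` module itself isn't visible in this view. As a rough sketch of what src/ui.rs presumably looks like after the move — an assumption, not shown in this commit: just the old chat_ui() relocated, made pub, with its ratatui imports carried along:

// src/ui.rs - hypothetical reconstruction of the new module; the body is the
// chat_ui() removed from main.rs below, so only the imports and `pub` are guesses.
use ratatui::{
    layout::{Constraint, Direction, Layout, Position},
    style::{Color, Style},
    text::{Line, Span},
    widgets::{Block, Borders, Paragraph, Wrap},
};

use crate::App;

pub fn chat_ui(f: &mut ratatui::Frame, app: &App) {
    // Split the frame vertically: scrolling chat history, then a 3-row input bar.
    let chunks = Layout::default()
        .direction(Direction::Vertical)
        .margin(1)
        .constraints([Constraint::Min(1), Constraint::Length(3)].as_ref())
        .split(f.area());
    let messages: Vec<Line> = app
        .messages
        .iter()
        .map(|m| Line::from(Span::raw(m.to_string())))
        .collect();
    // Scroll past anything that doesn't fit so the newest messages stay visible
    // (the -2 accounts for the block's top and bottom borders).
    let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
        .block(Block::default().borders(Borders::ALL).title("Chat"))
        .wrap(Wrap { trim: true })
        .scroll((
            app.messages
                .len()
                .saturating_sub((chunks[0].height - 2) as usize) as u16,
            0,
        ));
    f.render_widget(messages_block, chunks[0]);
    let input_text = if app.waiting {
        format!("> {} (waiting...)", &app.prompt)
    } else {
        format!("> {}", app.prompt)
    };
    let input = Paragraph::new(input_text)
        .style(Style::default().fg(Color::Yellow))
        .block(Block::default().borders(Borders::ALL).title("Input"));
    f.render_widget(input, chunks[1]);
    // +3 = one border column plus the "> " prefix; puts the cursor after the
    // last typed character.
    f.set_cursor_position(Position::new(
        chunks[1].x + app.prompt.len() as u16 + 3,
        chunks[1].y + 1,
    ));
}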


@@ -1,7 +1,8 @@
 use std::borrow::Cow;
+use std::pin::Pin;
 use std::time::{Duration, Instant};

-use chat::Message;
+use chat::{Action, Message};
 use clap::Parser;
 use futures_util::StreamExt;
 use reqwest::Client;
@@ -11,17 +12,10 @@ use crossterm::terminal::{
     EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode,
 };
-use ratatui::{
-    Terminal,
-    backend::CrosstermBackend,
-    layout::{Constraint, Direction, Layout},
-    style::{Color, Style},
-    text::{Line, Span},
-    widgets::{Block, Borders, Paragraph},
-};
-use serde::{Deserialize, Serialize};
+use ratatui::{Terminal, backend::CrosstermBackend};

 mod chat;
+mod ui;

 #[derive(Parser)]
 struct Args {
@@ -36,11 +30,11 @@ struct Args {
     #[arg(
         short,
         long,
-        help = "Should the response be streamed from ollama or sent all at once"
+        help = "(Broken) Should the response be streamed from ollama or sent all at once"
     )]
     stream: bool,

-    #[arg(short, long, help = "Show statistics in non-stream mode?")]
+    #[arg(short, long, help = "(Broken) Show statistics in non-stream mode?")]
     nerd_stats: bool,
 }
@@ -64,9 +58,9 @@ async fn main() -> anyhow::Result<()> {
     // setup crossterm
     enable_raw_mode()?;
-    let mut stdout = std::io::stdout();
-    crossterm::execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?;
-    let backend = CrosstermBackend::new(stdout);
+    let mut stdout_handle = std::io::stdout();
+    crossterm::execute!(stdout_handle, EnterAlternateScreen, EnableMouseCapture)?;
+    let backend = CrosstermBackend::new(stdout_handle);
     let mut terminal = Terminal::new(backend)?;

     let mut app = App {
@@ -78,10 +72,9 @@ async fn main() -> anyhow::Result<()> {
     let client = Client::new();

-    let header_prompt =
-        r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#;
+    let header_prompt = r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via valid, minified, UTF-8 JSON."#;

-    let tools_list = include_str!("tool/tools_list.json")
+    let tools_list = include_str!("data/tools_list.json")
         .parse::<serde_json::Value>()?
         .to_string();
@@ -94,23 +87,6 @@ async fn main() -> anyhow::Result<()> {
 6. Base claims strictly on provided data or tool results. If unsure, say so.
 7. Check your output; If you reach four consecutive newlines: *stop*"#;
-    let example_prompt = format!(
-        "Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}",
-        user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#,
-        assistant_tool_request_1 = format!(
-            "assistant: {{ \"action\":\"wiki_search\",\"arguments\":{{\"query\":\"American Crow\"}} }}"
-        ),
-        tool_result_1 = format!(
-            "tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}",
-            search_data = include_str!("tool/american_crow_wikipedia.md").to_string()
-        ),
-        assistant_a_1 = format!(
-            "assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}",
-            example1_assistant_message =
-                include_str!("tool/american_crow_example1_message.md").to_string()
-        )
-    );
     //let user_info_prompt = r#""#;
     let system_prompt = format!(
         "{header_prompt}\n
@@ -119,7 +95,7 @@ async fn main() -> anyhow::Result<()> {
     );

     loop {
-        terminal.draw(|f| chat_ui(f, &app))?;
+        terminal.draw(|f| ui::chat_ui(f, &app))?;

         if event::poll(Duration::from_millis(100))? {
             if let Event::Key(key) = event::read()? {
@@ -137,8 +113,9 @@ async fn main() -> anyhow::Result<()> {
                     app.messages.push(chat::Message::new(
                         chat::MessageRoles::User,
-                        chat::Action::ChatMessage,
-                        message_args));
+                        chat::Action::Chat,
+                        message_args,
+                    ));

                     let mut prompts = vec![chat::Prompt {
                         role: Cow::Borrowed("system"),
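
Both this hunk and the big one below lean on chat-module types that never appear in the diff itself. For orientation, an inferred sketch of the relevant enums — the variant names are real (each one is used somewhere in this diff), but the exact shapes, derives, and the Message/args plumbing are guesses:

// src/chat.rs (inferred) - only the names used in this diff are certain.
pub enum MessageRoles {
    User,
    Tool,
    // presumably also System / Assistant, not visible in this diff
}

pub enum AssistantTool {
    WikiSearch,
    WebSearch,
    GetDateTime,
    GetDirectoryTree,
    GetFileContents,
    InvalidTool,
}

pub enum Action {
    Chat, // renamed from ChatMessage in this commit
    Tool(AssistantTool),
}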
@@ -236,121 +213,96 @@ async fn batch_ollama_response<'a>(
     client: Client,
     req: chat::ChatRequest<'a>,
 ) -> anyhow::Result<()> {
-    let start = Instant::now();
-    let resp = client
-        .post("http://localhost:11434/api/chat")
-        .json(&req)
-        .send()
-        .await?;
-    let elapsed = start.elapsed();
+    batch_ollama_response_inner(app, client, req).await
+}
+
-    let status = resp.status();
-    let headers = resp.headers().clone();
-    let body_bytes = resp.bytes().await?;
-    match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
-        Ok(r) => app.messages.push(r.message),
-        Err(e) => {
-            println!("Failed to parse JSON: {}", e);
-            println!("Status: {}", status);
-            println!("Headers: {:#?}", headers);
-            // Try to print the body as text for debugging
-            if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
-                println!("Body text: {}", body_text);
-            } else {
-                println!("Body was not valid UTF-8");
+fn batch_ollama_response_inner<'a>(
+    app: &'a mut App,
+    client: Client,
+    req: chat::ChatRequest<'a>,
+) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>> {
+    Box::pin(async move {
+        let start = Instant::now();
+        let resp = client
+            .post("http://localhost:11434/api/chat")
+            .json(&req)
+            .send()
+            .await?;
+        let elapsed = start.elapsed();
+        let status = resp.status();
+        let headers = resp.headers().clone();
+        let body_bytes = resp.bytes().await?;
+        match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
+            Ok(r) => {
+                match r.message.content.action {
+                    chat::Action::Chat => app.messages.push(r.message),
+                    chat::Action::Tool(assistant_tool) => {
+                        match assistant_tool {
+                            chat::AssistantTool::WikiSearch => {
+                                //HACK: fake it for now, until I figure out how to grab a web page and display it in a way the model understands
+                                let tool_args = r.message.content.arguments.clone();
+                                app.messages.push(r.message);
+                                let search_term = match tool_args.get("query") {
+                                    Some(v) => v.as_str(),
+                                    None => todo!(),
+                                };
+                                let tool_response = match search_term {
+                                    "American Crow" => {
+                                        let r = args_builder! {
+                                            "result" => include_str!("data/american_crow_wikipedia.md")
+                                        };
+                                        r
+                                    }
+                                    "Black Bear" => {
+                                        let r = args_builder! {
+                                            "result" => include_str!("data/black_bear_wikipedia.md")
+                                        };
+                                        r
+                                    }
+                                    _ => {
+                                        let r = args_builder! {
+                                            "result" => "Search failed to return any valid data"
+                                        };
+                                        r
+                                    }
+                                };
+                                let tool_message = Message::from((
+                                    chat::MessageRoles::Tool,
+                                    Action::Tool(chat::AssistantTool::WikiSearch),
+                                    tool_response,
+                                ));
+                                app.messages.push(tool_message);
+                                //FIXME: model could recurse forever
+                                batch_ollama_response(app, client.clone(), req).await?;
+                            }
+                            chat::AssistantTool::WebSearch => todo!(),
+                            chat::AssistantTool::GetDateTime => todo!(),
+                            chat::AssistantTool::GetDirectoryTree => todo!(),
+                            chat::AssistantTool::GetFileContents => todo!(),
+                            chat::AssistantTool::InvalidTool => todo!(),
+                        }
+                    }
+                }
+            }
+            Err(e) => {
+                println!("Failed to parse JSON: {}", e);
+                println!("Status: {}", status);
+                println!("Headers: {:#?}", headers);
+                // Try to print the body as text for debugging
+                if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
+                    println!("Body text: {}", body_text);
+                } else {
+                    println!("Body was not valid UTF-8");
+                }
+            }
+        }
-    }
-    /*
-    if app.args.nerd_stats {
-        app.messages.push(format!(
-            "System : Response generated via {} model with timestamp {}",
-            resp.model, resp.created_at
-        ));
-        app.messages.push(format!(
-            "System : done_reason = {}, done = {}",
-            resp.done_reason, resp.done
-        ));
-        app.messages
-            .push(format!("System : Response timing statistics..."));
-        app.messages
-            .push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
-        app.messages.push(format!(
-            "System : Prompt tokens: {}",
-            resp.prompt_eval_count
-        ));
-        app.messages.push(format!(
-            "System : Prompt eval duration: {} ns",
-            resp.prompt_eval_duration
-        ));
-        app.messages
-            .push(format!("System : Output tokens: {}", resp.eval_count));
-        app.messages.push(format!(
-            "System : Output eval duration: {} ns",
-            resp.eval_duration
-        ));
-        app.messages.push(format!(
-            "System : Model 'warm up' time {}",
-            (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))
-        ));
-        let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
-        app.messages.push(format!(
-            "System > Output generation speed: {:.2} tokens/sec",
-            token_speed
-        ));
-    }
-    */
-    app.waiting = false;
-    Ok(())
-}
-
-fn chat_ui(f: &mut ratatui::Frame, app: &App) {
-    let chunks = Layout::default()
-        .direction(Direction::Vertical)
-        .margin(1)
-        .constraints([Constraint::Min(1), Constraint::Length(3)].as_ref())
-        .split(f.area());
-    let messages: Vec<Line> = app
-        .messages
-        .iter()
-        .map(|m| Line::from(Span::raw(m.to_string())))
-        .collect();
-    let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
-        .block(Block::default().borders(Borders::ALL).title("Chat"))
-        .wrap(ratatui::widgets::Wrap { trim: true })
-        .scroll((
-            app.messages
-                .len()
-                .saturating_sub((chunks[0].height - 2) as usize) as u16,
-            0,
-        ));
-    f.render_widget(messages_block, chunks[0]);
-    let input_text = if app.waiting {
-        format!("> {} (waiting...)", &app.prompt)
-    } else {
-        format!("> {}", app.prompt)
-    };
-    let input = Paragraph::new(input_text)
-        .style(Style::default().fg(Color::Yellow))
-        .block(Block::default().borders(Borders::ALL).title("Input"));
-    f.render_widget(input, chunks[1]);
-    use ratatui::layout::Position;
-    f.set_cursor_position(Position::new(
-        // the +3 comes from the 3 'characters' of space between the terminal edge and the text location
-        // this places the text cursor after the last entered character
-        chunks[1].x + app.prompt.len() as u16 + 3,
-        chunks[1].y + 1,
-    ));
+        app.waiting = false;
+        Ok(())
+    })
 }
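
The headline change in that last hunk is splitting batch_ollama_response into a thin `async fn` wrapper over batch_ollama_response_inner, which returns Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>>. That's the standard workaround for async recursion in Rust: an `async fn` that awaits itself would need an infinitely sized future type, so the recursive body is boxed behind a `dyn Future` pointer, which has a known size. A minimal standalone sketch of the same pattern (illustrative names, not from this repo):

use std::future::Future;
use std::pin::Pin;

// The thin wrapper keeps the ergonomic `async fn` signature for callers.
async fn countdown(n: u32) -> u32 {
    countdown_inner(n).await
}

// The recursive body lives behind Box<dyn Future>, so the compiler can size it.
fn countdown_inner(n: u32) -> Pin<Box<dyn Future<Output = u32> + Send>> {
    Box::pin(async move {
        if n == 0 {
            return 0;
        }
        // Recursing through the boxed indirection compiles fine.
        countdown(n - 1).await
    })
}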
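
On the `//FIXME: model could recurse forever` note: the WikiSearch arm re-enters batch_ollama_response after pushing the tool result, so a model that keeps emitting tool calls will loop until something gives out. One plausible guard (not in this commit) is to thread a depth budget through the recursion — a hypothetical sketch, with the real arguments elided:

use std::future::Future;
use std::pin::Pin;

const MAX_TOOL_DEPTH: u32 = 4; // hypothetical cap

fn respond_inner(depth: u32 /* , app, client, req */) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>> {
    Box::pin(async move {
        if depth >= MAX_TOOL_DEPTH {
            // Surface the runaway instead of looping forever.
            anyhow::bail!("model requested a tool {MAX_TOOL_DEPTH} times in a row; giving up");
        }
        // ...send the request; if the reply is a tool call, run the tool and
        // recurse with depth + 1 instead of the bare re-entry used above:
        // respond_inner(depth + 1 /* , ... */).await?;
        Ok(())
    })
}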