tons of changes, mostly type checking and prompt cleanup

2025-04-28 17:32:14 -04:00
parent 27382af199
commit 950d09779c
7 changed files with 425 additions and 121 deletions

@@ -1,9 +1,10 @@
use std::borrow::Cow;
use std::time::{Duration, Instant};
use chat::Message;
use clap::Parser;
use futures_util::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crossterm::event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode};
use crossterm::terminal::{
@@ -18,13 +19,16 @@ use ratatui::{
text::{Line, Span},
widgets::{Block, Borders, Paragraph},
};
use serde::{Deserialize, Serialize};
mod chat;
#[derive(Parser)]
struct Args {
#[arg(
short,
long,
default_value = "mixtral:8x7b-instruct-v0.1-q5_K_M",
default_value = "mistral:latest",
help = "Model name to use"
)]
model: String,
@@ -35,60 +39,29 @@ struct Args {
help = "Should the response be streamed from ollama or sent all at once"
)]
stream: bool,
}
#[derive(Deserialize, Debug)]
struct StreamChunk {
message: StreamMessage,
}
#[derive(Deserialize, Debug)]
struct StreamMessage {
role: String,
content: String,
}
#[derive(Serialize, Deserialize)]
struct Prompt<'a> {
role: &'a str,
content: &'a str,
}
#[derive(Serialize)]
struct ChatRequest<'a> {
model: &'a str,
messages: Vec<Prompt<'a>>,
stream: bool,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Deserialize)]
struct ChatResponse {
model: String,
created_at: String,
message: Message,
done_reason: String,
done: bool,
total_duration: u64,
eval_count: u64,
eval_duration: u64,
prompt_eval_count: u64,
prompt_eval_duration: u64,
#[arg(short, long, help = "Show statistics in non-stream mode?")]
nerd_stats: bool,
}
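// Illustrative invocation of the flags defined above (model name is an example;
// clap derives the kebab-case long flags from the field names):
// $ cargo run -- --model mistral:latest --stream --nerd-stats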
struct App {
args: Args,
prompt: String,
messages: Vec<String>,
waiting: bool
messages: Vec<Message>,
waiting: bool,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// parse arguments
let args = match Args::try_parse() {
Ok(args) => args,
Err(e) => {
e.print().expect("Error writing clap error");
std::process::exit(0);
}
};
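// note: this exits 0 even for genuine parse errors; clap's `Error::exit()`
// would pick the conventional exit code for --help vs. invalid input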
// setup crossterm
enable_raw_mode()?;
let mut stdout = std::io::stdout();
@@ -97,19 +70,54 @@ async fn main() -> anyhow::Result<()> {
let mut terminal = Terminal::new(backend)?;
let mut app = App {
args,
prompt: String::new(),
messages: vec![String::from("Welcome to the OxiAI TUI Interface!")],
waiting: false
messages: vec![],
waiting: false,
};
// parse arguments
let args = Args::parse();
let client = Client::new();
let model_name = &args.model;
let system_prompt =
"[INST]You are a helpful, logical and extremely technical AI assistant.[INST]";
let header_prompt =
r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via JSON"#;
let tools_list = include_str!("tool/tools_list.json")
.parse::<serde_json::Value>()?
.to_string();
let rules_prompt = r#"Rules:
1. Think silently; never reveal your chain-of-thought.
2. To use a tool: {"action":"<tool>","arguments":{...}}
3. To reply directly: {"action":"chat","arguments":{"response":"..."}}
4. If a question is vague, comparative, descriptive, or about ideas rather than specifics: use the web_search tool.
5. If a question clearly names a specific object, animal, person, place: use the wiki_search tool.
6. Base claims strictly on provided data or tool results. If unsure, say so.
7. Perform a JSON Self-check to ensure valid, minified, UTF-8 JSON.
8. Finish with a coherent sentence; if you reach four consecutive newlines: **STOP.**"#;
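// Sketch: the rules above pin replies to one minified JSON object per turn, so
// a struct like this could decode them with the serde derives already imported
// (ActionReply and its fields are illustrative, not this commit's chat API):
#[derive(Deserialize)]
struct ActionReply {
action: String,               // "chat", "web_search", or "wiki_search"
arguments: serde_json::Value, // e.g. {"response":"..."} or {"query":"..."}
}
let _parsed: ActionReply =
serde_json::from_str(r#"{"action":"chat","arguments":{"response":"hi"}}"#)?;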
let example_prompt = format!(
"Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}",
user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#,
assistant_tool_request_1 =
r#"assistant: { "action":"wiki_search","arguments":{"query":"American Crow"} }"#,
tool_result_1 = format!(
"tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}",
search_data = include_str!("tool/american_crow_wikipedia.md")
),
assistant_a_1 = format!(
"assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}",
example1_assistant_message = include_str!("tool/american_crow_example1_message.md")
)
);
//let user_info_prompt = r#""#;
let system_prompt = format!(
"{header_prompt}\n
{tools_list}\n\n
{rules_prompt}\n"
);
loop {
terminal.draw(|f| chat_ui(f, &app))?;
@@ -123,36 +131,49 @@ async fn main() -> anyhow::Result<()> {
}
KeyCode::Enter => {
//TODO: refactor to a parser function to take the contents of the app.prompt vec and do fancy stuff with it (like commands)
let prompt = app.prompt.clone();
app.messages.push(format!("[INST]{}[INST]", prompt));
let message_args = args_builder! {
"response" => app.prompt.clone(),
};
app.prompt.clear();
let user_prompt = app.messages.pop()
.expect("No user prompt received (empty user_prompt)");
app.messages.push(chat::Message::new(
chat::MessageRoles::User,
chat::Action::ChatMessage,
message_args));
let req = ChatRequest {
model: model_name,
stream: args.stream,
messages: vec![
Prompt {
role: "system",
content: system_prompt,
},
Prompt {
role: "user",
content: &user_prompt,
},
],
let mut prompts = vec![chat::Prompt {
role: Cow::Borrowed("system"),
content: Cow::Borrowed(&system_prompt),
}];
prompts.extend(
app.messages
.iter()
.map(|msg| chat::Prompt::from(msg.clone())),
);
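// Cow lets the single system prompt stay a cheap borrow while each cloned
// history message converts into an owned Prompt, so one struct covers both
// without re-copying the large system string every turn.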
// clone the model name into a local first: a reference into a temporary
// `.clone()` inside the struct literal would be dropped before `req` is used
let model = app.args.model.clone();
let req = chat::ChatRequest {
model: &model,
stream: app.args.stream,
format: "json",
stop: vec!["\n\n\n\n"],
options: Some(chat::ChatOptions {
temperature: Some(0.3),
top_p: Some(0.92),
top_k: Some(50),
repeat_penalty: Some(1.1),
seed: None,
}),
messages: prompts,
};
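// For reference, `req` should serialize to a body roughly shaped like this
// (whitespace and field order illustrative; `seed: None` serializes as null):
// {"model":"mistral:latest","stream":false,"format":"json","stop":["\n\n\n\n"],
//  "options":{"temperature":0.3,"top_p":0.92,"top_k":50,"repeat_penalty":1.1,"seed":null},
//  "messages":[{"role":"system","content":"..."},{"role":"user","content":"..."}]}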
match args.stream {
app.waiting = true;
match app.args.stream {
true => {
stream_ollama_response(&mut app, client.clone(), req)
.await?;
todo!();
stream_ollama_response(&mut app, client.clone(), req).await?;
}
false => {
ollama_response(&mut app, client.clone(), req)
.await?;
batch_ollama_response(&mut app, client.clone(), req).await?;
}
}
}
@@ -175,12 +196,12 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}
//FIXME: streaming replies are harder to work with for now, save this for the future
async fn stream_ollama_response(
app: &mut App,
client: Client,
req: ChatRequest<'_>,
req: chat::ChatRequest<'_>,
) -> anyhow::Result<()> {
app.waiting = true;
let mut resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
@@ -198,60 +219,95 @@ async fn stream_ollama_response(
if line.is_empty() {
continue;
}
let parsed: serde_json::Result<StreamChunk> = serde_json::from_slice(line);
let parsed: serde_json::Result<chat::StreamChunk> = serde_json::from_slice(line);
if let Ok(parsed) = parsed {
assistant_line.push_str(&parsed.message.content);
assistant_line.push_str(&parsed.message.content.to_string());
}
}
}
app.messages.push(assistant_line);
app.waiting = false;
//FIXME: fix this later
//app.messages.push(assistant_line);
Ok(())
}
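// For context: with stream=true Ollama returns newline-delimited JSON, one
// content fragment per line with a final "done": true record, which is why the
// loop above skips empty lines and parses each line as its own StreamChunk, e.g.
// {"message":{"role":"assistant","content":"Hel"},"done":false}
// {"message":{"role":"assistant","content":"lo"},"done":false}
// {"message":{"role":"assistant","content":""},"done":true, ...timing fields}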
async fn ollama_response<'a>(
async fn batch_ollama_response<'a>(
app: &mut App,
client: Client,
req: ChatRequest<'a>,
req: chat::ChatRequest<'a>,
) -> anyhow::Result<()> {
app.waiting = true;
let start = Instant::now();
let resp: ChatResponse = client
let resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
.send()
.await?
.json()
.await?;
let elapsed = start.elapsed();
let status = resp.status();
let headers = resp.headers().clone();
let body_bytes = resp.bytes().await?;
match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
Ok(r) => app.messages.push(r.message),
Err(e) => {
println!("Failed to parse JSON: {}", e);
println!("Status: {}", status);
println!("Headers: {:#?}", headers);
// Try to print the body as text for debugging
if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
println!("Body text: {}", body_text);
} else {
println!("Body was not valid UTF-8");
}
}
}
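// note: these println! calls run while the terminal is still in raw mode on the
// alternate screen, so the debug output may render garbled until the TUI exits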
app.messages.push(format!("{} : {}", resp.message.role, resp.message.content));
app.messages.push(format!("System : Response generated via {} model with timestamp {}",
resp.model, resp.created_at
));
/*
if app.args.nerd_stats {
app.messages.push(format!(
"System : Response generated via {} model with timestamp {}",
resp.model, resp.created_at
));
app.messages.push(format!("System : done_reason = {}, done = {}",
resp.done_reason, resp.done
));
app.messages.push(format!(
"System : done_reason = {}, done = {}",
resp.done_reason, resp.done
));
app.messages.push(format!("System : Response timing statistics..."));
app.messages
.push(format!("System : Response timing statistics..."));
app.messages.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
app.messages.push(format!("System : Prompt tokens: {}", resp.prompt_eval_count));
app.messages.push(format!("System : Prompt eval duration: {} ns", resp.prompt_eval_duration));
app.messages.push(format!("System : Output tokens: {}", resp.eval_count));
app.messages.push(format!("System : Output eval duration: {} ns", resp.eval_duration));
app.messages.push(format!("System : Model 'warm up' time {}", (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))));
let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
app.messages.push(format!("System > Output generation speed: {:.2} tokens/sec", token_speed));
app.messages
.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
app.messages.push(format!(
"System : Prompt tokens: {}",
resp.prompt_eval_count
));
app.messages.push(format!(
"System : Prompt eval duration: {} ns",
resp.prompt_eval_duration
));
app.messages
.push(format!("System : Output tokens: {}", resp.eval_count));
app.messages.push(format!(
"System : Output eval duration: {} ns",
resp.eval_duration
));
app.messages.push(format!(
"System : Model 'warm up' time: {} ns",
(resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))
));
let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
app.messages.push(format!(
"System > Output generation speed: {:.2} tokens/sec",
token_speed
));
}
*/
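// Sketch: the commented-out stats above boil down to this conversion from
// Ollama's nanosecond counters (helper is illustrative, not wired in anywhere):
fn tokens_per_sec(eval_count: u64, eval_duration_ns: u64) -> f64 {
// e.g. 256 tokens over 8_000_000_000 ns => 32.0 tokens/sec
eval_count as f64 / (eval_duration_ns as f64 / 1_000_000_000.0)
}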
app.waiting = false;
Ok(())
}
@@ -265,7 +321,7 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
let messages: Vec<Line> = app
.messages
.iter()
.map(|m| Line::from(Span::raw(m.clone())))
.map(|m| Line::from(Span::raw(m.to_string())))
.collect();
let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
@@ -279,7 +335,7 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
));
f.render_widget(messages_block, chunks[0]);
let input_text = if app.waiting {
format!("> {} (waiting...)", &app.prompt)
} else {
@@ -292,13 +348,10 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
f.render_widget(input, chunks[1]);
use ratatui::layout::Position;
f.set_cursor_position(
Position::new(
// the +3 comes from the 3 'characters' of space between the terminal edge and the text location
// this places the text cursor after the last entered character
chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1,
)
);
f.set_cursor_position(Position::new(
// the +3 comes from the 3 'characters' of space between the terminal edge and the text location
// this places the text cursor after the last entered character
chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1,
));
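// note: `len()` counts bytes; for non-ASCII input, `app.prompt.chars().count()`
// would track the visible cursor column more closely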
}