tons of changes, mostly type checking and prompt clean up

2025-04-28 17:32:14 -04:00
parent 27382af199
commit 950d09779c
7 changed files with 425 additions and 121 deletions

Cargo.toml

@@ -10,7 +10,6 @@ anyhow = "1.0"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.44", features = ["full"] }
chrono = "0.4"
clap = { version = "4.5", features = ["derive"] }

src/chat/mod.rs Normal file (200 lines)

@@ -0,0 +1,200 @@
use serde::{Deserialize, Serialize};
use serde::de::{self, Deserializer};
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{Display, Formatter, Result as FmtResult};
#[derive(Deserialize, Debug)]
pub struct StreamChunk {
pub message: StreamMessage,
}
#[derive(Deserialize, Debug)]
pub struct StreamMessage {
pub role: String,
pub content: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Prompt<'a> {
pub role: Cow<'a, str>,
pub content: Cow<'a, str>,
}
impl<'a> From<Message> for Prompt<'a> {
fn from(message: Message) -> Self {
Prompt {
role: Cow::Owned(message.role),
content: Cow::Owned(message.content.to_string()),
}
}
}
#[derive(Serialize, Debug)]
pub struct ChatOptions {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub repeat_penalty: Option<f32>,
pub seed: Option<u32>,
}
#[derive(Serialize, Debug)]
pub struct ChatRequest<'a> {
pub model: &'a str,
pub messages: Vec<Prompt<'a>>,
pub stream: bool,
pub format: &'a str,
pub stop: Vec<&'a str>,
pub options: Option<ChatOptions>,
}
pub enum MessageRoles {
System = 0,
Tool,
User,
Assistant,
Other,
}
impl Display for MessageRoles {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
let role: &str = match self {
MessageRoles::System => "system",
MessageRoles::Tool => "tool",
MessageRoles::User => "user",
MessageRoles::Assistant => "assistant",
//HACK: Handle this cleanly, if the model hallucinates we crash :^)
MessageRoles::Other => todo!(),
};
write!(f, "{}", role)
}
}
#[derive(Serialize, Deserialize, Clone, PartialEq)]
pub struct Message {
pub role: String,
#[serde(deserialize_with = "Message::de_content")]
pub content: ActionPacket,
}
impl Message {
pub fn new(role: MessageRoles, action: Action, arguments: HashMap<String, String>) -> Self {
Self {
role: role.to_string(),
content: ActionPacket::new(action, arguments),
}
}
// Custom deserializer function
fn de_content<'de, D>(deserializer: D) -> Result<ActionPacket, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(de::Error::custom)
}
}
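Ollama returns message.content as a string of JSON that itself encodes an ActionPacket, which is why de_content parses twice. A minimal test-style sketch, not part of this commit, with an illustrative payload:
#[test]
fn de_content_sketch() {
    // The outer JSON is the Ollama message; the inner JSON string is the ActionPacket.
    let raw = r#"{"role":"assistant","content":"{\"action\":\"chat\",\"arguments\":{\"response\":\"Hello\"}}"}"#;
    let msg: Message = serde_json::from_str(raw).expect("nested JSON should parse");
    assert_eq!(msg.content.action, "chat");
    assert_eq!(msg.content.arguments.get("response").map(String::as_str), Some("Hello"));
}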
impl From<(MessageRoles, Action, HashMap<String, String>)> for Message {
fn from((role, action, arguments): (MessageRoles, Action, HashMap<String, String>)) -> Self {
Message::new(role, action, arguments)
}
}
impl Display for Message {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "{}", self.content)
}
}
#[derive(Serialize, Deserialize, PartialEq)]
pub enum AssistantTool {
WikiSearch,
WebSearch,
GetDateTime,
GetDirectoryTree,
GetFileContents,
InvalidTool,
}
impl Display for AssistantTool {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
let res = match self {
AssistantTool::WikiSearch => "wiki_search",
AssistantTool::WebSearch => "web_search",
AssistantTool::GetDateTime => "get_datetime",
AssistantTool::GetDirectoryTree => "get_dirtree",
AssistantTool::GetFileContents => "get_file",
//HACK: Handle this cleanly, if the model hallucinates we crash :^)
AssistantTool::InvalidTool => todo!(),
};
write!(f, "{}", res)
}
}
#[derive(Serialize, Deserialize, PartialEq)]
pub enum Action {
ChatMessage,
Tool(AssistantTool),
}
impl Display for Action {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Action::ChatMessage => write!(f, "{}", "chat"),
Action::Tool(tool_name) => write!(f, "{tool_name}"),
}
}
}
#[derive(Serialize, Deserialize, Clone, PartialEq)]
pub struct ActionPacket {
pub action: String,
pub arguments: HashMap<String, String>,
}
impl ActionPacket {
pub fn new(action: Action, arguments: HashMap<String, String>) -> Self {
Self {
action: action.to_string(),
arguments,
}
}
}
impl Display for ActionPacket {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match serde_json::to_string(&self.arguments) {
Ok(arguments_json) => write!(f, "{} {}", self.action, arguments_json),
Err(_) => write!(f, "{} {{}}", self.action), // fallback to empty JSON if error
}
}
}
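Display on ActionPacket yields the action name followed by the arguments as minified JSON; this is the text chat_ui ends up rendering for each Message. A minimal sketch, not part of this commit, with illustrative values:
#[test]
fn action_packet_display_sketch() {
    // Display writes "<action> <arguments serialized as JSON>".
    let mut arguments = std::collections::HashMap::new();
    arguments.insert("query".to_string(), "American Crow".to_string());
    let packet = ActionPacket::new(Action::Tool(AssistantTool::WikiSearch), arguments);
    assert_eq!(packet.to_string(), r#"wiki_search {"query":"American Crow"}"#);
}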
#[derive(Deserialize)]
pub struct ChatResponse {
pub model: String,
pub created_at: String,
pub message: Message,
pub done: bool,
pub done_reason: Option<String>,
pub total_duration: Option<u64>,
pub eval_count: Option<u64>,
pub eval_duration: Option<u64>,
pub prompt_eval_count: Option<u64>,
pub prompt_eval_duration: Option<u64>,
}
#[macro_export]
macro_rules! args_builder {
( $( $key:expr => $value:expr ),* $(,)? ) => {{
let mut map = ::std::collections::HashMap::new();
$(
map.insert($key.into(), $value.into());
)*
map
}};
}
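A matching sketch for the args_builder! macro above, which is how main.rs wraps the user's prompt into a Message; the key and value here are illustrative:
#[test]
fn args_builder_sketch() {
    // The macro builds a HashMap<String, String> via .into() on each key and value.
    let arguments = args_builder! {
        "response" => "Hello there",
    };
    let message = Message::new(MessageRoles::User, Action::ChatMessage, arguments);
    assert_eq!(message.role, "user");
    assert_eq!(message.to_string(), r#"chat {"response":"Hello there"}"#);
}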

src/main.rs

@@ -1,9 +1,10 @@
use std::borrow::Cow;
use std::time::{Duration, Instant};
use chat::Message;
use clap::Parser;
use futures_util::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crossterm::event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode};
use crossterm::terminal::{
@@ -18,13 +19,16 @@ use ratatui::{
text::{Line, Span},
widgets::{Block, Borders, Paragraph},
};
use serde::{Deserialize, Serialize};
mod chat;
#[derive(Parser)]
struct Args {
#[arg(
short,
long,
default_value = "mixtral:8x7b-instruct-v0.1-q5_K_M",
default_value = "mistral:latest",
help = "Model name to use"
)]
model: String,
@@ -35,60 +39,29 @@ struct Args {
help = "Should the response be streamed from ollama or sent all at once"
)]
stream: bool,
}
#[derive(Deserialize, Debug)]
struct StreamChunk {
message: StreamMessage,
}
#[derive(Deserialize, Debug)]
struct StreamMessage {
role: String,
content: String,
}
#[derive(Serialize, Deserialize)]
struct Prompt<'a> {
role: &'a str,
content: &'a str,
}
#[derive(Serialize)]
struct ChatRequest<'a> {
model: &'a str,
messages: Vec<Prompt<'a>>,
stream: bool,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Deserialize)]
struct ChatResponse {
model: String,
created_at: String,
message: Message,
done_reason: String,
done: bool,
total_duration: u64,
eval_count: u64,
eval_duration: u64,
prompt_eval_count: u64,
prompt_eval_duration: u64,
#[arg(short, long, help = "Show statistics in non-stream mode?")]
nerd_stats: bool,
}
struct App {
args: Args,
prompt: String,
messages: Vec<String>,
waiting: bool
messages: Vec<Message>,
waiting: bool,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// parse arguments
let args = match Args::try_parse() {
Ok(args) => args,
Err(e) => {
e.print().expect("Error writing clap error");
std::process::exit(0);
}
};
// setup crossterm
enable_raw_mode()?;
let mut stdout = std::io::stdout();
@@ -97,19 +70,54 @@ async fn main() -> anyhow::Result<()> {
let mut terminal = Terminal::new(backend)?;
let mut app = App {
args,
prompt: String::new(),
messages: vec![String::from("Welcome to the OxiAI TUI Interface!")],
waiting: false
messages: vec![],
waiting: false,
};
// parse arguments
let args = Args::parse();
let client = Client::new();
let model_name = &args.model;
let system_prompt =
"[INST]You are a helpful, logical and extremely technical AI assistant.[INST]";
let header_prompt =
r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via JSON"#;
let tools_list = include_str!("tool/tools_list.json")
.parse::<serde_json::Value>()?
.to_string();
let rules_prompt = r#"Rules:
1. Think silently; never reveal your chain-of-thought.
2. To use a tool: {"action":"<tool>","arguments":{...}}
3. To reply directly: {"action":"chat","arguments":{"response":"..."}}
4. If a question is vague, comparative, descriptive, or about ideas rather than specifics: use the web_search tool.
5. If a question clearly names a specific object, animal, person, or place: use the wiki_search tool.
6. Base claims strictly on provided data or tool results. If unsure, say so.
7. Perform a JSON self-check to ensure valid, minified, UTF-8 JSON.
8. Finish with a coherent sentence; if you reach four consecutive newlines: **STOP.**"#;
let example_prompt = format!(
"Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}",
user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#,
assistant_tool_request_1 = format!(
"assistant: {{ \"action\":\"wiki_search\",\"arguments\":{{\"query\":\"American Crow\"}} }}"
),
tool_result_1 = format!(
"tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}",
search_data = include_str!("tool/american_crow_wikipedia.md").to_string()
),
assistant_a_1 = format!(
"assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}",
example1_assistant_message =
include_str!("tool/american_crow_example1_message.md").to_string()
)
);
//let user_info_prompt = r#""#;
let system_prompt = format!(
"{header_prompt}\n
{tools_list}\n\n
{rules_prompt}\n"
);
loop {
terminal.draw(|f| chat_ui(f, &app))?;
@@ -123,36 +131,49 @@ async fn main() -> anyhow::Result<()> {
}
KeyCode::Enter => {
//TODO: refactor into a parser function that takes the contents of app.prompt and handles special input (like commands)
let prompt = app.prompt.clone();
app.messages.push(format!("[INST]{}[INST]", prompt));
let message_args = args_builder! {
"response" => app.prompt.clone(),
};
app.prompt.clear();
let user_prompt = app.messages.pop()
.expect("No user prompt received (empty user_prompt)");
app.messages.push(chat::Message::new(
chat::MessageRoles::User,
chat::Action::ChatMessage,
message_args));
let req = ChatRequest {
model: model_name,
stream: args.stream,
messages: vec![
Prompt {
role: "system",
content: system_prompt,
},
Prompt {
role: "user",
content: &user_prompt,
},
],
let mut prompts = vec![chat::Prompt {
role: Cow::Borrowed("system"),
content: Cow::Borrowed(&system_prompt),
}];
prompts.extend(
app.messages
.iter()
.map(|msg| chat::Prompt::from(msg.clone())),
);
let req = chat::ChatRequest {
model: &app.args.model.clone(),
stream: app.args.stream,
format: "json",
stop: vec!["\n\n\n\n"],
options: Some(chat::ChatOptions {
temperature: Some(0.3),
top_p: Some(0.92),
top_k: Some(50),
repeat_penalty: Some(1.1),
seed: None,
}),
messages: prompts,
};
match args.stream {
app.waiting = true;
match app.args.stream {
true => {
stream_ollama_response(&mut app, client.clone(), req)
.await?;
todo!();
stream_ollama_response(&mut app, client.clone(), req).await?;
}
false => {
ollama_response(&mut app, client.clone(), req)
.await?;
batch_ollama_response(&mut app, client.clone(), req).await?;
}
}
}
@@ -175,12 +196,12 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}
//FIXME: streaming replies are harder to work with for now; revisit this in the future
async fn stream_ollama_response(
app: &mut App,
client: Client,
req: ChatRequest<'_>,
req: chat::ChatRequest<'_>,
) -> anyhow::Result<()> {
app.waiting = true;
let mut resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
@@ -198,60 +219,95 @@ async fn stream_ollama_response(
if line.is_empty() {
continue;
}
let parsed: serde_json::Result<StreamChunk> = serde_json::from_slice(line);
let parsed: serde_json::Result<chat::StreamChunk> = serde_json::from_slice(line);
if let Ok(parsed) = parsed {
assistant_line.push_str(&parsed.message.content);
assistant_line.push_str(&parsed.message.content.to_string());
}
}
}
app.messages.push(assistant_line);
app.waiting = false;
//FIXME: fix this later
//app.messages.push(assistant_line);
Ok(())
}
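Each line streamed from /api/chat is a standalone JSON object, and StreamChunk only declares the message field, so everything else is ignored during deserialization. A minimal test-style sketch, not part of this commit, with an illustrative chunk:
#[test]
fn stream_chunk_parse_sketch() {
    // Fields StreamChunk does not declare ("model", "created_at", "done", ...) are ignored by serde.
    let line = br#"{"model":"mistral:latest","created_at":"2025-04-28T21:32:14Z","message":{"role":"assistant","content":"Hel"},"done":false}"#;
    let chunk: chat::StreamChunk = serde_json::from_slice(line).expect("chunk should parse");
    assert_eq!(chunk.message.content, "Hel");
}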
async fn ollama_response<'a>(
async fn batch_ollama_response<'a>(
app: &mut App,
client: Client,
req: ChatRequest<'a>,
req: chat::ChatRequest<'a>,
) -> anyhow::Result<()> {
app.waiting = true;
let start = Instant::now();
let resp: ChatResponse = client
let resp = client
.post("http://localhost:11434/api/chat")
.json(&req)
.send()
.await?
.json()
.await?;
let elapsed = start.elapsed();
let status = resp.status();
let headers = resp.headers().clone();
let body_bytes = resp.bytes().await?;
app.messages.push(format!("{} : {}", resp.message.role, resp.message.content));
app.messages.push(format!("System : Response generated via {} model with timestamp {}",
match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
Ok(r) => app.messages.push(r.message),
Err(e) => {
println!("Failed to parse JSON: {}", e);
println!("Status: {}", status);
println!("Headers: {:#?}", headers);
// Try to print the body as text for debugging
if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
println!("Body text: {}", body_text);
} else {
println!("Body was not valid UTF-8");
}
}
}
/*
if app.args.nerd_stats {
app.messages.push(format!(
"System : Response generated via {} model with timestamp {}",
resp.model, resp.created_at
));
app.messages.push(format!("System : done_reason = {}, done = {}",
app.messages.push(format!(
"System : done_reason = {}, done = {}",
resp.done_reason, resp.done
));
app.messages.push(format!("System : Response timing statistics..."));
app.messages
.push(format!("System : Response timing statistics..."));
app.messages.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
app.messages.push(format!("System : Prompt tokens: {}", resp.prompt_eval_count));
app.messages.push(format!("System : Prompt eval duration: {} ns", resp.prompt_eval_duration));
app.messages.push(format!("System : Output tokens: {}", resp.eval_count));
app.messages.push(format!("System : Output eval duration: {} ns", resp.eval_duration));
app.messages.push(format!("System : Model 'warm up' time {}", (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))));
app.messages
.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
app.messages.push(format!(
"System : Prompt tokens: {}",
resp.prompt_eval_count
));
app.messages.push(format!(
"System : Prompt eval duration: {} ns",
resp.prompt_eval_duration
));
app.messages
.push(format!("System : Output tokens: {}", resp.eval_count));
app.messages.push(format!(
"System : Output eval duration: {} ns",
resp.eval_duration
));
app.messages.push(format!(
"System : Model 'warm up' time {}",
(resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))
));
let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
app.messages.push(format!("System > Output generation speed: {:.2} tokens/sec", token_speed));
app.messages.push(format!(
"System > Output generation speed: {:.2} tokens/sec",
token_speed
));
}
*/
app.waiting = false;
Ok(())
}
@@ -265,7 +321,7 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
let messages: Vec<Line> = app
.messages
.iter()
.map(|m| Line::from(Span::raw(m.clone())))
.map(|m| Line::from(Span::raw(m.to_string())))
.collect();
let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
@@ -292,13 +348,10 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
f.render_widget(input, chunks[1]);
use ratatui::layout::Position;
f.set_cursor_position(
Position::new(
f.set_cursor_position(Position::new(
// the +3 comes from the 3 'characters' of space between the terminal edge and the text location
// this places the text cursor after the last entered character
chunks[1].x + app.prompt.len() as u16 + 3,
chunks[1].y + 1,
)
);
));
}

src/tool/american_crow_example1_message.md Normal file

@@ -0,0 +1 @@
The American crow is a large passerine bird species native to North America, known for its black plumage and distinct calls. Adults typically measure 40-50 cm (16-20 in) from beak to tail and weigh between 300-600 g (11-21 oz), with males slightly larger than females. They are highly intelligent, adaptable, and often found in various human environments. The American crow can be distinguished from other similar corvid species by its size, beak shape, and call patterns. This bird is a common and widespread species across North America, serving as a useful bioindicator for diseases such as the West Nile virus.

src/tool/american_crow_wikipedia.md Normal file

@@ -0,0 +1 @@
The American crow (Corvus brachyrhynchos) is a large passerine bird species of the family Corvidae. It is a common bird found throughout much of North America. American crows are the New World counterpart to the carrion crow and the hooded crow of Eurasia; they all occupy the same ecological niche. Although the American crow and the hooded crow are very similar in size, structure and behavior, their calls and visual appearance are different. From beak to tail, an American crow measures 40-50 cm (16-20 in), almost half of which is tail. Its wingspan is 85-100 cm (33-39 in). Mass varies from about 300 to 600 g (11 to 21 oz), with males tending to be larger than females. Plumage is all black, with iridescent feathers. It looks much like other all-black corvids. They are very intelligent, and adaptable to human environments. The most usual call is CaaW!-CaaW!-CaaW! They can be distinguished from the common raven (C. corax) because American crows are smaller and the beak is slightly less pronounced; from the fish crow (C. ossifragus) because American crows do not hunch and fluff their throat feathers when they call; and from the carrion crow (C. corone) by size, as the carrion crow is larger and of a stockier build. American crows are common, widespread, and susceptible to the West Nile virus, making them useful as a bioindicator to track the virus's spread. Direct transmission of the virus from crows to humans is impossible.

View File

@@ -0,0 +1 @@
The American black bear (Ursus americanus), or simply black bear, is a species of medium-sized bear endemic to North America. It is the continent's smallest and most widely distributed bear species. It is an omnivore, with a diet varying greatly depending on season and location. It typically lives in largely forested areas but will leave forests in search of food and is sometimes attracted to human communities due to the immediate availability of food. The International Union for Conservation of Nature (IUCN) lists the American black bear as a least-concern species because of its widespread distribution and a large population, estimated to be twice that of all other bear species combined. Along with the brown bear (Ursus arctos), it is one of only two modern bear species not considered by the IUCN to be globally threatened with extinction. Taxonomy and evolution The American black bear is not closely related to the brown bear or polar bear, though all three species are found in North America; genetic studies reveal that they split from a common ancestor 5.05 million years ago (mya). American and Asian black bears are considered sister taxa and are more closely related to each other than to the other modern species of bears. According to recent studies, the sun bear is also a relatively recent split from this lineage. A small primitive bear called Ursus abstrusus is the oldest known North American fossil member of the genus Ursus, dated to 4.95 mya. This suggests that U. abstrusus may be the direct ancestor of the American black bear, which evolved in North America. Although Wolverton and Lyman still consider U. vitabilis an "apparent precursor to modern black bears", it has also been placed within U. americanus. The ancestors of American black bears and Asian black bears diverged from sun bears 4.58 mya. The American black bear then split from the Asian black bear 4.08 mya. The earliest American black bear fossils, which were located in Port Kennedy, Pennsylvania, greatly resemble the Asian species, though later specimens grew to sizes comparable to grizzly bears. From the Holocene to the present, American black bears seem to have shrunk in size, but this has been disputed because of problems with dating these fossil specimens. The American black bear lived during the same period as the giant and lesser short-faced bears (Arctodus simus and A. pristinus, respectively) and the Florida spectacled bear (Tremarctos floridanus). These tremarctine bears evolved from bears that had emigrated from Asia to the Americas 7-8 mya. The giant and lesser short-faced bears are thought to have been heavily carnivorous and the Florida spectacled bear more herbivorous, while the American black bears remained arboreal omnivores, like their Asian ancestors. The American black bear's generalist behavior allowed it to exploit a wider variety of foods and has been given as a reason why, of these three genera, it alone survived climate and vegetative changes through the last Ice Age while the other, more specialized North American predators became extinct. However, both Arctodus and Tremarctos had survived several other, previous ice ages. After these prehistoric ursids became extinct during the last glacial period 10,000 years ago, American black bears were probably the only bear present in much of North America until the migration of brown bears to the rest of the continent.

src/tool/tools_list.json Normal file (49 lines)

@@ -0,0 +1,49 @@
{
"tools": [
{
"type": "function",
"function": {
"name": "wiki_search",
"description": "Search Wikipedia",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search term to request"
}
}
},
"required": ["query"]
}
},
{
"type": "function",
"function": {
"name": "web_search",
"description": "Search DuckDuckGo (a web search engine)",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search term to request"
}
}
},
"required": ["query"]
}
},
{
"type": "function",
"function": {
"name": "get_datetime_iso8601",
"description": "Get the current date and time in iso8601 format to the seconds",
"parameters": {
"type": "None"
},
"required": []
}
}
]
}
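Per rule 2 of the system prompt, the model is expected to answer with JSON such as {"action":"wiki_search","arguments":{"query":"..."}}, and the action names here are meant to line up with the strings produced by the AssistantTool Display impl in src/chat/mod.rs. A hedged routing sketch, not part of this commit, assuming it lives in main.rs where mod chat is in scope, with placeholder branch bodies:
// Sketch only: route a parsed ActionPacket by its action string.
fn dispatch(packet: &chat::ActionPacket) -> Option<String> {
    match packet.action.as_str() {
        "wiki_search" => packet.arguments.get("query").map(|q| format!("TODO: wiki search for {q}")),
        "web_search" => packet.arguments.get("query").map(|q| format!("TODO: web search for {q}")),
        "chat" => packet.arguments.get("response").cloned(),
        _ => None, // unknown or hallucinated action
    }
}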