From 590c5ce8d3f5f276815886d5ebe9cc2aac7a5162 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Fri, 19 Aug 2022 06:46:54 -0400 Subject: [PATCH] regular structs for stmt,expr / track span Docs are somewhat out of date. Tracking span should enable re-use of type-check results. --- book/src/tutorial/ir.md | 61 ++++- calc-example/calc/Cargo.toml | 1 + calc-example/calc/src/ir.rs | 36 ++- calc-example/calc/src/main.rs | 4 +- calc-example/calc/src/parser.rs | 396 +++++++++++++++++++++++++++----- 5 files changed, 415 insertions(+), 83 deletions(-) diff --git a/book/src/tutorial/ir.md b/book/src/tutorial/ir.md index cbc35197..98391a73 100644 --- a/book/src/tutorial/ir.md +++ b/book/src/tutorial/ir.md @@ -69,17 +69,49 @@ When you change the value of an input field, that increments a 'revision counter indicating that some inputs are different now. When we talk about a "revision" of the database, we are referring to the state of the database in between changes to the input values. -## Tracked structs +### Representing the parsed program -Next we will define a **tracked struct** to represent the functions in our input. +Next we will define a **tracked struct**. Whereas inputs represent the *start* of a computation, tracked structs represent intermediate values created during your computation. -In this case, we are going to parse the raw input program, and create a `Function` for each of the functions defined by the user. + +In this case, the parser is going to take in the `SourceProgram` struct that we saw and return a `Program` that represents the fully parsed program: + +```rust +{{#include ../../../calc-example/calc/src/ir.rs:program}} +``` + +Like with an input, the fields of a tracked struct are also stored in the database. +Unlike an input, those fields are immutable (they cannot be "set"), and salsa compares them across revisions to know when they have changed. 
+In this case, if parsing the input produced the same `Program` result (e.g., because the only change to the input was some trailing whitespace, perhaps), +then subsequent parts of the computation won't need to re-execute. +(We'll revisit the role of tracked structs in reuse more in future parts of the IR.) + +Apart from the fields being immutable, the API for working with a tracked struct is quite similar to an input: + +* You can create a new value by using `new`, but with a tracked struct, you only need an `&dyn` database, not `&mut` (e.g., `Program::new(&db, some_statements)`) +* You use a getter to read the value of a field, just like with an input (e.g., `my_func.statements(db)` to read the `statements` field). + * In this case, the field is tagged as `#[return_ref]`, which means that the getter will return a `&Vec`, instead of cloning the vector. + +## Representing functions + +We will also use a tracked struct to represent each function: +Next we will define a **tracked struct**. +Whereas inputs represent the *start* of a computation, tracked structs represent intermediate values created during your computation. + +The `Function` struct is going to be created by the parser to represent each of the functions defined by the user: ```rust {{#include ../../../calc-example/calc/src/ir.rs:functions}} ``` -Unlike with inputs, the fields of tracked structs are immutable once created. Otherwise, working with a tracked struct is quite similar to an input: +Like with an input, the fields of a tracked struct are also stored in the database. +Unlike an input, those fields are immutable (they cannot be "set"), and salsa compares them across revisions to know when they have changed. +If we had created some `Function` instance `f`, for example, we might find that the `f.body` field changes +because the user changed the definition of `f`. 
+This would mean that we have to re-execute those parts of the code that depended on `f.body` +(but not those parts of the code that depended on the body of *other* functions). + +Apart from the fields being immutable, the API for working with a tracked struct is quite similar to an input: * You can create a new value by using `new`, but with a tracked struct, you only need an `&dyn` database, not `&mut` (e.g., `Function::new(&db, some_name, some_args, some_body)`) * You use a getter to read the value of a field, just like with an input (e.g., `my_func.args(db)` to read the `args` field). @@ -116,14 +148,6 @@ let f2 = FunctionId::new(&db, "my_string".to_string()); assert_eq!(f1, f2); ``` -### Expressions and statements - -We'll also intern expressions and statements. This is convenient primarily because it allows us to have recursive structures very easily. Since we don't really need the "cheap equality comparison" aspect of interning, this isn't the most efficient choice, and many compilers would opt to represent expressions/statements in some other way. - -```rust -{{#include ../../../calc-example/calc/src/ir.rs:statements_and_expressions}} -``` - ### Interned ids are guaranteed to be consistent within a revision, but not across revisions (but you don't have to care) Interned ids are guaranteed not to change within a single revision, so you can intern things from all over your program and get back consistent results. @@ -134,3 +158,16 @@ just a different one than they saw in a previous revision. In other words, within a salsa computation, you can assume that interning produces a single consistent integer, and you don't have to think about it. If however you export interned identifiers outside the computation, and then change the inputs, they may not longer be valid or may refer to different values. 
+### Expressions and statements + +We won't use any special "salsa structs" for expressions and statements: + +```rust +{{#include ../../../calc-example/calc/src/ir.rs:statements_and_expressions}} +``` + +Since statements and expressions are not tracked, this implies that we are only attempting to get incremental re-use at the granularity of functions -- +whenever anything in a function body changes, we consider the entire function body dirty and re-execute anything that depended on it. +It usually makes sense to draw some kind of "reasonably coarse" boundary like this. + +One downside of the way we have set things up: we inlined the position into each of the structs. diff --git a/calc-example/calc/Cargo.toml b/calc-example/calc/Cargo.toml index a75c0082..85aa0f7e 100644 --- a/calc-example/calc/Cargo.toml +++ b/calc-example/calc/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +derive-new = "0.5.9" salsa = { path = "../../components/salsa-2022", package = "salsa-2022" } ordered-float = "3.0" diff --git a/calc-example/calc/src/ir.rs b/calc-example/calc/src/ir.rs index 87d7eb67..0c27fbd5 100644 --- a/calc-example/calc/src/ir.rs +++ b/calc-example/calc/src/ir.rs @@ -1,3 +1,4 @@ +use derive_new::new; use ordered_float::OrderedFloat; use salsa::debug::DebugWithDb; @@ -23,18 +24,23 @@ pub struct FunctionId { } // ANCHOR_END: interned_ids -// ANCHOR: statements_and_expressions +// ANCHOR: program #[salsa::tracked] pub struct Program { + #[return_ref] statements: Vec, } +// ANCHOR_END: program -#[salsa::interned] +// ANCHOR: statements_and_expressions +#[derive(Eq, PartialEq, Debug, Hash, new)] pub struct Statement { + span: Span, + data: StatementData, } -#[derive(Eq, PartialEq, Clone, Hash)] +#[derive(Eq, PartialEq, Debug, Hash)] pub enum StatementData { /// Defines `fn () = ` Function(Function), @@ -42,15 +48,16 @@ pub enum StatementData { Print(Expression), } 
-#[salsa::interned] +#[derive(Eq, PartialEq, Debug, Hash, new)] pub struct Expression { - #[return_ref] + span: Span, + data: ExpressionData, } -#[derive(Eq, PartialEq, Clone, Hash)] +#[derive(Eq, PartialEq, Debug, Hash)] pub enum ExpressionData { - Op(Expression, Op, Expression), + Op(Box, Op, Box), Number(OrderedFloat), Variable(VariableId), Call(FunctionId, Vec), @@ -77,7 +84,7 @@ impl DebugWithDb for Function { impl DebugWithDb for Statement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn crate::Db) -> std::fmt::Result { - match self.data(db) { + match &self.data { StatementData::Function(a) => DebugWithDb::fmt(&a, f, db), StatementData::Print(a) => DebugWithDb::fmt(&a, f, db), } @@ -87,7 +94,7 @@ impl DebugWithDb for Statement { // ANCHOR: expression_debug_impl impl DebugWithDb for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn crate::Db) -> std::fmt::Result { - match self.data(db) { + match &self.data { ExpressionData::Op(a, b, c) => f .debug_tuple("ExpressionData::Op") .field(&a.debug(db)) // use `a.debug(db)` for interned things @@ -149,11 +156,22 @@ impl DebugWithDb for Diagnostic { pub struct Function { #[id] name: FunctionId, + + name_span: Span, + args: Vec, + + #[return_ref] body: Expression, } // ANCHOR_END: functions +#[salsa::tracked] +pub struct Span { + pub start: usize, + pub end: usize, +} + // ANCHOR: diagnostic #[salsa::accumulator] pub struct Diagnostics(Diagnostic); diff --git a/calc-example/calc/src/main.rs b/calc-example/calc/src/main.rs index 96a7c709..5f92550c 100644 --- a/calc-example/calc/src/main.rs +++ b/calc-example/calc/src/main.rs @@ -5,11 +5,11 @@ pub struct Jar( crate::ir::Program, crate::ir::VariableId, crate::ir::FunctionId, - crate::ir::Expression, - crate::ir::Statement, crate::ir::Function, crate::ir::Diagnostics, + crate::ir::Span, crate::parser::parse_statements, + crate::type_check::type_check_program, ); // ANCHOR_END: jar_struct diff --git a/calc-example/calc/src/parser.rs 
b/calc-example/calc/src/parser.rs index 2236cbb9..0af2f8da 100644 --- a/calc-example/calc/src/parser.rs +++ b/calc-example/calc/src/parser.rs @@ -2,11 +2,11 @@ use ordered_float::OrderedFloat; use crate::ir::{ Diagnostic, Diagnostics, Expression, ExpressionData, Function, FunctionId, Op, Program, - SourceProgram, Statement, StatementData, VariableId, + SourceProgram, Span, Statement, StatementData, VariableId, }; // ANCHOR: parse_statements -#[salsa::tracked(return_ref)] +#[salsa::tracked] pub fn parse_statements(db: &dyn crate::Db, source: SourceProgram) -> Program { // Get the source text from the database let source_text = source.text(db); @@ -93,34 +93,44 @@ impl Parser<'_> { self.source_text[self.position..].chars().next() } + // Returns a span ranging from `start_position` until the current position (exclusive) + fn span_from(&self, start_position: usize) -> Span { + Span::new(self.db, start_position, self.position) + } + fn consume(&mut self, ch: char) { debug_assert!(self.peek() == Some(ch)); self.position += ch.len_utf8(); } - fn skip_whitespace(&mut self) { + /// Skips whitespace and returns the new position. + fn skip_whitespace(&mut self) -> usize { while let Some(ch) = self.peek() { if ch.is_whitespace() { self.consume(ch); } else { - return; + break; } } + self.position } // ANCHOR: parse_statement fn parse_statement(&mut self) -> Option { - self.skip_whitespace(); + let start_position = self.skip_whitespace(); let word = self.word()?; if word == "fn" { let func = self.parse_function()?; - Some(Statement::new(self.db, StatementData::Function(func))) - // ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ - // Create a new interned enum... | - // using the "data" type. 
+ Some(Statement::new( + self.span_from(start_position), + StatementData::Function(func), + )) } else if word == "print" { let expr = self.parse_expression()?; - Some(Statement::new(self.db, StatementData::Print(expr))) + Some(Statement::new( + self.span_from(start_position), + StatementData::Print(expr), + )) } else { None } @@ -129,7 +139,9 @@ impl Parser<'_> { // ANCHOR: parse_function fn parse_function(&mut self) -> Option { + let start_position = self.skip_whitespace(); let name = self.word()?; + let name_span = self.span_from(start_position); let name: FunctionId = FunctionId::new(self.db, name); // ^^^^^^^^^^^^^^^ // Create a new interned struct. @@ -138,7 +150,7 @@ impl Parser<'_> { self.ch(')')?; self.ch('=')?; let body = self.parse_expression()?; - Some(Function::new(self.db, name, args, body)) + Some(Function::new(self.db, name, name_span, args, body)) // ^^^^^^^^^^^^^ // Create a new entity struct. } @@ -149,9 +161,9 @@ impl Parser<'_> { } fn low_op(&mut self) -> Option { - if self.ch('+').is_some() { + if let Some(_) = self.ch('+') { Some(Op::Add) - } else if self.ch('-').is_some() { + } else if let Some(_) = self.ch('-') { Some(Op::Subtract) } else { None @@ -166,9 +178,9 @@ impl Parser<'_> { } fn high_op(&mut self) -> Option { - if self.ch('*').is_some() { + if let Some(_) = self.ch('*') { Some(Op::Multiply) - } else if self.ch('/').is_some() { + } else if let Some(_) = self.ch('/') { Some(Op::Divide) } else { None @@ -180,11 +192,15 @@ impl Parser<'_> { mut parse_expr: impl FnMut(&mut Self) -> Option, mut op: impl FnMut(&mut Self) -> Option, ) -> Option { + let start_position = self.skip_whitespace(); let mut expr1 = parse_expr(self)?; while let Some(op) = op(self) { let expr2 = parse_expr(self)?; - expr1 = Expression::new(self.db, ExpressionData::Op(expr1, op, expr2)); + expr1 = Expression::new( + self.span_from(start_position), + ExpressionData::Op(Box::new(expr1), op, Box::new(expr2)), + ); } Some(expr1) @@ -194,22 +210,29 @@ impl Parser<'_> { 
/// /// On failure, skips arbitrary tokens. fn parse_expression2(&mut self) -> Option { + let start_position = self.skip_whitespace(); if let Some(w) = self.word() { - if let Some(()) = self.ch('(') { + if let Some(_) = self.ch('(') { let f = FunctionId::new(self.db, w); let args = self.parse_expressions()?; self.ch(')')?; - return Some(Expression::new(self.db, ExpressionData::Call(f, args))); + return Some(Expression::new( + self.span_from(start_position), + ExpressionData::Call(f, args), + )); } let v = VariableId::new(self.db, w); - Some(Expression::new(self.db, ExpressionData::Variable(v))) + Some(Expression::new( + self.span_from(start_position), + ExpressionData::Variable(v), + )) } else if let Some(n) = self.number() { Some(Expression::new( - self.db, + self.span_from(start_position), ExpressionData::Number(OrderedFloat::from(n)), )) - } else if let Some(()) = self.ch('(') { + } else if let Some(_) = self.ch('(') { let expr = self.parse_expression()?; self.ch(')')?; Some(expr) @@ -249,12 +272,12 @@ impl Parser<'_> { /// Parses a single character. /// /// Even on failure, only skips whitespace. - fn ch(&mut self, c: char) -> Option<()> { - self.skip_whitespace(); + fn ch(&mut self, c: char) -> Option { + let start_position = self.skip_whitespace(); match self.peek() { Some(p) if c == p => { self.consume(c); - Some(()) + Some(self.span_from(start_position)) } _ => None, } @@ -269,6 +292,7 @@ impl Parser<'_> { // In this loop, if we consume any characters, we always // return `Some`. let mut s = String::new(); + let position = self.position; while let Some(ch) = self.peek() { if ch.is_alphabetic() || ch == '_' { s.push(ch); @@ -292,7 +316,7 @@ impl Parser<'_> { /// /// Even on failure, only skips whitespace. 
fn number(&mut self) -> Option { - self.skip_whitespace(); + let start_position = self.skip_whitespace(); self.probe(|this| { // 👆 We need the call to `probe` here because we could consume @@ -354,11 +378,49 @@ fn parse_print() { ( Program { statements: [ - Statement( - Id { - value: 1, - }, - ), + Statement { + span: Span( + Id { + value: 5, + }, + ), + data: Print( + Expression { + span: Span( + Id { + value: 4, + }, + ), + data: Op( + Expression { + span: Span( + Id { + value: 1, + }, + ), + data: Number( + OrderedFloat( + 1.0, + ), + ), + }, + Add, + Expression { + span: Span( + Id { + value: 3, + }, + ), + data: Number( + OrderedFloat( + 2.0, + ), + ), + }, + ), + }, + ), + }, ], }, [], @@ -382,31 +444,163 @@ fn parse_example() { ( Program { statements: [ - Statement( - Id { - value: 1, - }, - ), - Statement( - Id { - value: 2, - }, - ), - Statement( - Id { - value: 3, - }, - ), - Statement( - Id { - value: 4, - }, - ), - Statement( - Id { - value: 5, - }, - ), + Statement { + span: Span( + Id { + value: 10, + }, + ), + data: Function( + Function( + Id { + value: 1, + }, + ), + ), + }, + Statement { + span: Span( + Id { + value: 22, + }, + ), + data: Function( + Function( + Id { + value: 2, + }, + ), + ), + }, + Statement { + span: Span( + Id { + value: 29, + }, + ), + data: Print( + Expression { + span: Span( + Id { + value: 28, + }, + ), + data: Call( + FunctionId( + Id { + value: 1, + }, + ), + [ + Expression { + span: Span( + Id { + value: 24, + }, + ), + data: Number( + OrderedFloat( + 3.0, + ), + ), + }, + Expression { + span: Span( + Id { + value: 26, + }, + ), + data: Number( + OrderedFloat( + 4.0, + ), + ), + }, + ], + ), + }, + ), + }, + Statement { + span: Span( + Id { + value: 34, + }, + ), + data: Print( + Expression { + span: Span( + Id { + value: 33, + }, + ), + data: Call( + FunctionId( + Id { + value: 2, + }, + ), + [ + Expression { + span: Span( + Id { + value: 31, + }, + ), + data: Number( + OrderedFloat( + 1.0, + ), + ), + }, + ], + ), + }, 
+ ), + }, + Statement { + span: Span( + Id { + value: 39, + }, + ), + data: Print( + Expression { + span: Span( + Id { + value: 38, + }, + ), + data: Op( + Expression { + span: Span( + Id { + value: 35, + }, + ), + data: Number( + OrderedFloat( + 11.0, + ), + ), + }, + Multiply, + Expression { + span: Span( + Id { + value: 37, + }, + ), + data: Number( + OrderedFloat( + 2.0, + ), + ), + }, + ), + }, + ), + }, ], }, [], @@ -440,11 +634,93 @@ fn parse_precedence() { ( Program { statements: [ - Statement( - Id { - value: 1, - }, - ), + Statement { + span: Span( + Id { + value: 11, + }, + ), + data: Print( + Expression { + span: Span( + Id { + value: 10, + }, + ), + data: Op( + Expression { + span: Span( + Id { + value: 7, + }, + ), + data: Op( + Expression { + span: Span( + Id { + value: 1, + }, + ), + data: Number( + OrderedFloat( + 1.0, + ), + ), + }, + Add, + Expression { + span: Span( + Id { + value: 6, + }, + ), + data: Op( + Expression { + span: Span( + Id { + value: 3, + }, + ), + data: Number( + OrderedFloat( + 2.0, + ), + ), + }, + Multiply, + Expression { + span: Span( + Id { + value: 5, + }, + ), + data: Number( + OrderedFloat( + 3.0, + ), + ), + }, + ), + }, + ), + }, + Add, + Expression { + span: Span( + Id { + value: 9, + }, + ), + data: Number( + OrderedFloat( + 4.0, + ), + ), + }, + ), + }, + ), + }, ], }, [],