use std::any::TypeId;
use bevy_ecs::{
change_detection::MutUntyped, prelude::*, world::unsafe_world_cell::UnsafeWorldCell,
};
use bevy_reflect::{Reflect, ReflectFromPtr, TypeRegistry};
use smallvec::{smallvec, SmallVec};
/// Reasons an access through a [`RestrictedWorldView`] can fail.
#[derive(Debug)]
pub enum Error {
    /// The view is not permitted to access the resource with this `TypeId`.
    NoAccessToResource(TypeId),
    /// The view is not permitted to access this `(Entity, component TypeId)` pair.
    NoAccessToComponent(EntityComponent),
    /// The resource is not present in the world, or the world has no resource id for it.
    ResourceDoesNotExist(TypeId),
    /// The entity does not exist, or it does not have this component.
    ComponentDoesNotExist(EntityComponent),
    /// The world has no `ComponentId` registered for this component `TypeId`.
    NoComponentId(TypeId),
    /// The `TypeRegistry` has no registration for this `TypeId`.
    NoTypeRegistration(TypeId),
    /// The type's registration lacks the named type data (e.g. `"ReflectFromPtr"`).
    NoTypeData(TypeId, &'static str),
}
/// A component on one specific entity: the entity plus the component's `TypeId`.
type EntityComponent = (Entity, TypeId);
/// A view into a [`World`] that restricts which resources and which entity components may be
/// accessed through it.
///
/// The `split_off_*` methods hand out views whose permissions do not overlap. This is what
/// allows mutable borrows obtained from several views to be held at the same time.
pub struct RestrictedWorldView<'w> {
    // The viewed world; every access through it must be checked against the two fields below.
    world: UnsafeWorldCell<'w>,
    // Resources (by `TypeId`) this view may access.
    resources: Allowed<TypeId>,
    // `(Entity, component TypeId)` pairs this view may access.
    components: Allowed<EntityComponent>,
}
/// A set of permitted values of type `T`, expressed either positively or negatively.
#[derive(Clone)]
enum Allowed<T> {
    /// Only the listed values are permitted.
    AllowList(SmallVec<[T; 2]>),
    /// Every value except the listed ones is permitted.
    ForbidList(SmallVec<[T; 2]>),
}
impl<T: Clone + PartialEq> Allowed<T> {
    /// Permits exactly one value.
    fn allow_just(value: T) -> Allowed<T> {
        Allowed::AllowList(smallvec![value])
    }

    /// Permits exactly the given values.
    fn allow(values: impl IntoIterator<Item = T>) -> Allowed<T> {
        Allowed::AllowList(values.into_iter().collect())
    }

    /// Permits every value (an empty forbid list).
    fn everything() -> Allowed<T> {
        Allowed::ForbidList(SmallVec::new())
    }

    /// Permits no value (an empty allow list).
    fn nothing() -> Allowed<T> {
        Allowed::AllowList(SmallVec::new())
    }

    /// Returns `true` if `value` is permitted by this set.
    fn allows_access_to(&self, value: T) -> bool {
        match self {
            Allowed::AllowList(list) => list.contains(&value),
            Allowed::ForbidList(list) => !list.contains(&value),
        }
    }

    /// Returns a copy of this set with `value` no longer permitted.
    ///
    /// # Panics
    /// Panics if this is an allow list that does not contain `value`, because removing access
    /// that was never granted indicates a logic error in the caller.
    fn without(&self, value: T) -> Allowed<T> {
        match self {
            Allowed::AllowList(list) => {
                assert!(list.contains(&value), "called `without` without access");
                let mut new = list.clone();
                // Remove *every* occurrence. An allow list built from an iterator may contain
                // duplicates, and leaving one behind would keep `value` accessible.
                new.retain(|item| *item != value);
                Allowed::AllowList(new)
            }
            Allowed::ForbidList(list) => {
                let mut new = list.clone();
                new.push(value);
                Allowed::ForbidList(new)
            }
        }
    }

    /// Returns a copy of this set with none of `values` permitted.
    ///
    /// Duplicate entries in `values` are tolerated.
    ///
    /// # Panics
    /// Panics if this is an allow list that does not contain one of `values`.
    fn without_many(&self, values: impl Iterator<Item = T>) -> Allowed<T>
    where
        T: Copy,
    {
        match self {
            Allowed::AllowList(list) => {
                // Accumulate all removals into a single list. Previously each removal went
                // into a shadowed copy that was discarded, so nothing was ever removed.
                let mut new = list.clone();
                for value in values {
                    // Check against the original list so repeated `values` do not panic.
                    assert!(list.contains(&value), "called `without` without access");
                    new.retain(|item| *item != value);
                }
                Allowed::AllowList(new)
            }
            Allowed::ForbidList(list) => {
                let mut new = list.clone();
                new.extend(values);
                Allowed::ForbidList(new)
            }
        }
    }
}
impl<'a> From<&'a mut World> for RestrictedWorldView<'a> {
fn from(value: &'a mut World) -> Self {
RestrictedWorldView::new(value)
}
}
impl<'w> RestrictedWorldView<'w> {
    /// Creates a view over `world` that may access every resource and every component.
    pub fn new(world: &'w mut World) -> RestrictedWorldView<'w> {
        let cell = world.as_unsafe_world_cell();
        RestrictedWorldView {
            world: cell,
            resources: Allowed::everything(),
            components: Allowed::everything(),
        }
    }

    /// Splits `world` into two non-overlapping views.
    ///
    /// The first view may access only resources; the second may access only components.
    pub fn resources_components(
        world: &'w mut World,
    ) -> (RestrictedWorldView<'w>, RestrictedWorldView<'w>) {
        let cell = world.as_unsafe_world_cell();
        let resources_only = RestrictedWorldView {
            world: cell,
            resources: Allowed::everything(),
            components: Allowed::nothing(),
        };
        let components_only = RestrictedWorldView {
            world: cell,
            resources: Allowed::nothing(),
            components: Allowed::everything(),
        };
        (resources_only, components_only)
    }

    /// Returns the underlying world cell.
    ///
    /// Access made through the returned cell is *not* checked against this view's permissions.
    pub fn world(&self) -> UnsafeWorldCell<'w> {
        self.world
    }

    /// Returns `true` if this view may access the resource with `type_id`.
    pub fn allows_access_to_resource(&self, type_id: TypeId) -> bool {
        let permitted = &self.resources;
        permitted.allows_access_to(type_id)
    }

    /// Returns `true` if this view may access `component` (an `(Entity, TypeId)` pair).
    pub fn allows_access_to_component(&self, component: EntityComponent) -> bool {
        let permitted = &self.components;
        permitted.allows_access_to(component)
    }

    /// Splits this view into `(only `resource`, everything else this view could access)`.
    ///
    /// # Panics
    /// Panics if this view may not access `resource`.
    pub fn split_off_resource(
        &mut self,
        resource: TypeId,
    ) -> (RestrictedWorldView<'_>, RestrictedWorldView<'_>) {
        assert!(self.allows_access_to_resource(resource));
        let cell = self.world;
        let remaining_resources = self.resources.without(resource);
        let remaining_components = self.components.clone();
        (
            RestrictedWorldView {
                world: cell,
                resources: Allowed::allow_just(resource),
                components: Allowed::nothing(),
            },
            RestrictedWorldView {
                world: cell,
                resources: remaining_resources,
                components: remaining_components,
            },
        )
    }

    /// Consumes the view, borrowing resource `R` for `'w` and returning a view of the rest.
    ///
    /// Returns `None` if `R` is not present in the world.
    ///
    /// # Panics
    /// Panics if this view may not access `R`.
    pub fn split_off_resource_typed<R: Resource>(
        self,
    ) -> Option<(Mut<'w, R>, RestrictedWorldView<'w>)> {
        let type_id = TypeId::of::<R>();
        assert!(self.allows_access_to_resource(type_id));
        let RestrictedWorldView {
            world,
            resources,
            components,
        } = self;
        // SAFETY: this view may access `R` (asserted above), and the returned remainder
        // excludes `R`, so no other borrow of `R` can be produced through it.
        let resource = unsafe { world.get_resource_mut::<R>() }?;
        let remainder = RestrictedWorldView {
            world,
            resources: resources.without(type_id),
            components,
        };
        Some((resource, remainder))
    }

    /// Splits this view into `(only `component`, everything else this view could access)`.
    ///
    /// # Panics
    /// Panics if this view may not access `component`.
    pub fn split_off_component(
        &mut self,
        component: EntityComponent,
    ) -> (RestrictedWorldView<'_>, RestrictedWorldView<'_>) {
        assert!(self.allows_access_to_component(component));
        let cell = self.world;
        let remaining_resources = self.resources.clone();
        let remaining_components = self.components.without(component);
        (
            RestrictedWorldView {
                world: cell,
                resources: Allowed::nothing(),
                components: Allowed::allow_just(component),
            },
            RestrictedWorldView {
                world: cell,
                resources: remaining_resources,
                components: remaining_components,
            },
        )
    }

    /// Splits this view into `(only the given components, everything else)`.
    ///
    /// `components` is iterated several times, which is why it must be `Copy`.
    ///
    /// # Panics
    /// Panics if this view may not access one of `components`.
    pub fn split_off_components(
        &mut self,
        components: impl Iterator<Item = EntityComponent> + Copy,
    ) -> (RestrictedWorldView<'_>, RestrictedWorldView<'_>) {
        components.for_each(|component| assert!(self.allows_access_to_component(component)));
        let cell = self.world;
        let split = RestrictedWorldView {
            world: cell,
            resources: Allowed::nothing(),
            components: Allowed::allow(components),
        };
        let remainder = RestrictedWorldView {
            world: cell,
            resources: self.resources.clone(),
            components: self.components.without_many(components),
        };
        (split, remainder)
    }
}
impl<'w> RestrictedWorldView<'w> {
    /// Returns `true` if `entity` exists in the underlying world.
    ///
    /// This only reads entity metadata and does not consult the view's permissions.
    pub fn contains_entity(&self, entity: Entity) -> bool {
        self.world().entities().contains(entity)
    }

    /// Mutably borrows resource `R` if this view may access it.
    ///
    /// # Errors
    /// - [`Error::NoAccessToResource`] if `R` is outside this view.
    /// - [`Error::ResourceDoesNotExist`] if `R` is not present in the world.
    pub fn get_resource_mut<R: Resource>(&mut self) -> Result<Mut<'_, R>, Error> {
        // SAFETY: `&mut self` guarantees no other borrow obtained through this view is alive.
        unsafe { self.get_resource_unchecked_mut() }
    }

    /// Mutably borrows two different resources at the same time.
    ///
    /// # Panics
    /// Panics if `R1` and `R2` are the same type.
    pub fn get_two_resources_mut<R1: Resource, R2: Resource>(
        &mut self,
    ) -> (Result<Mut<'_, R1>, Error>, Result<Mut<'_, R2>, Error>) {
        assert_ne!(TypeId::of::<R1>(), TypeId::of::<R2>());
        // SAFETY: `&mut self` is exclusive, and the assertion above guarantees that the two
        // borrows refer to different resources, so they cannot alias.
        let r1 = unsafe { self.get_resource_unchecked_mut::<R1>() };
        let r2 = unsafe { self.get_resource_unchecked_mut::<R2>() };
        (r1, r2)
    }

    /// Mutably borrows resource `R` through a shared reference to the view.
    ///
    /// # Safety
    /// The caller must ensure that no other reference to `R` obtained from the underlying
    /// world is alive for as long as the returned `Mut` is used.
    ///
    /// # Errors
    /// Same as [`Self::get_resource_mut`].
    unsafe fn get_resource_unchecked_mut<R: Resource>(&self) -> Result<Mut<'_, R>, Error> {
        let type_id = TypeId::of::<R>();
        if !self.allows_access_to_resource(type_id) {
            return Err(Error::NoAccessToResource(type_id));
        }
        // SAFETY: the view may access `R` (checked above); aliasing is the caller's obligation.
        let value = unsafe {
            self.world()
                .get_resource_mut::<R>()
                .ok_or(Error::ResourceDoesNotExist(type_id))?
        };
        Ok(value)
    }

    /// Mutably borrows the resource with `type_id` as `&mut dyn Reflect`.
    ///
    /// The value is returned together with the `set_changed` callback produced by
    /// `crate::utils::mut_untyped_split`. Presumably that callback marks the resource as
    /// changed; confirm against that helper.
    ///
    /// # Errors
    /// - [`Error::NoAccessToResource`] if the resource is outside this view.
    /// - [`Error::ResourceDoesNotExist`] if the world has no resource id for `type_id` or the
    ///   resource is not present.
    /// - [`Error::NoTypeRegistration`] / [`Error::NoTypeData`] if `type_registry` cannot
    ///   reflect the type.
    pub fn get_resource_reflect_mut_by_id(
        &mut self,
        type_id: TypeId,
        type_registry: &TypeRegistry,
    ) -> Result<(&'_ mut dyn Reflect, impl FnOnce() + '_), Error> {
        if !self.allows_access_to_resource(type_id) {
            return Err(Error::NoAccessToResource(type_id));
        }
        let component_id = self
            .world()
            .components()
            .get_resource_id(type_id)
            .ok_or(Error::ResourceDoesNotExist(type_id))?;
        // SAFETY: access was checked above and `&mut self` rules out other borrows through
        // this view.
        let value = unsafe {
            self.world()
                .get_resource_mut_by_id(component_id)
                .ok_or(Error::ResourceDoesNotExist(type_id))?
        };
        // SAFETY: `component_id` was resolved from `type_id`, so `value` holds a `type_id`.
        let value = unsafe { mut_untyped_to_reflect(value, type_registry, type_id)? };
        Ok(value)
    }

    /// Mutably borrows `component` on `entity` as `&mut dyn Reflect`.
    ///
    /// Returns `(value, changed, set_changed)`. `changed` is the handle's `is_changed()` state
    /// as read before the reflection conversion.
    ///
    /// # Errors
    /// - [`Error::NoAccessToComponent`] if `(entity, component)` is outside this view.
    /// - [`Error::NoComponentId`] if the world has no `ComponentId` for `component`.
    /// - [`Error::ComponentDoesNotExist`] if `entity` does not exist or lacks the component.
    /// - [`Error::NoTypeRegistration`] / [`Error::NoTypeData`] if `type_registry` cannot
    ///   reflect the type.
    pub fn get_entity_component_reflect(
        &mut self,
        entity: Entity,
        component: TypeId,
        type_registry: &TypeRegistry,
    ) -> Result<(&'_ mut dyn Reflect, bool, impl FnOnce() + '_), Error> {
        // SAFETY: `&mut self` rules out any other live borrow obtained through this view.
        let value =
            unsafe { self.get_entity_component_mut_untyped_unchecked(entity, component)? };
        let changed = value.is_changed();
        // SAFETY: the lookup resolved the `ComponentId` from `component`, so `value` holds it.
        let (value, set_changed) =
            unsafe { mut_untyped_to_reflect(value, type_registry, component) }?;
        Ok((value, changed, set_changed))
    }

    /// Like [`Self::get_entity_component_reflect`], but only borrows the view immutably and
    /// does not report the changed state.
    ///
    /// # Safety
    /// The caller must ensure that no other reference to this entity's component, obtained
    /// from the underlying world, is alive while the returned reference is used.
    pub(crate) unsafe fn get_entity_component_reflect_unchecked(
        &self,
        entity: Entity,
        component: TypeId,
        type_registry: &TypeRegistry,
    ) -> Result<(&'_ mut dyn Reflect, impl FnOnce() + '_), Error> {
        // SAFETY: forwarded to this function's caller.
        let value =
            unsafe { self.get_entity_component_mut_untyped_unchecked(entity, component)? };
        // SAFETY: the lookup resolved the `ComponentId` from `component`, so `value` holds it.
        unsafe { mut_untyped_to_reflect(value, type_registry, component) }
    }

    /// Looks up `component` on `entity` and returns an untyped mutable handle to it.
    ///
    /// Shared by both reflection accessors so they perform the same access check and lookups
    /// in the same order.
    ///
    /// # Safety
    /// The caller must ensure no other live reference aliases the returned component.
    ///
    /// # Errors
    /// - [`Error::NoAccessToComponent`] if `(entity, component)` is outside this view.
    /// - [`Error::NoComponentId`] if the world has no `ComponentId` for `component`.
    /// - [`Error::ComponentDoesNotExist`] if `entity` does not exist or lacks the component.
    unsafe fn get_entity_component_mut_untyped_unchecked(
        &self,
        entity: Entity,
        component: TypeId,
    ) -> Result<MutUntyped<'_>, Error> {
        if !self.allows_access_to_component((entity, component)) {
            return Err(Error::NoAccessToComponent((entity, component)));
        }
        let component_id = self
            .world()
            .components()
            .get_id(component)
            .ok_or(Error::NoComponentId(component))?;
        // SAFETY: the view may access `(entity, component)` (checked above); aliasing is the
        // caller's obligation.
        let value = unsafe {
            self.world()
                .get_entity(entity)
                .ok_or(Error::ComponentDoesNotExist((entity, component)))?
                .get_mut_by_id(component_id)
                .ok_or(Error::ComponentDoesNotExist((entity, component)))?
        };
        Ok(value)
    }
}
unsafe fn mut_untyped_to_reflect<'a>(
value: MutUntyped<'a>,
type_registry: &TypeRegistry,
type_id: TypeId,
) -> Result<(&'a mut dyn Reflect, impl FnOnce() + 'a), Error> {
let registration = type_registry
.get(type_id)
.ok_or(Error::NoTypeRegistration(type_id))?;
let reflect_from_ptr = registration
.data::<ReflectFromPtr>()
.ok_or(Error::NoTypeData(type_id, "ReflectFromPtr"))?;
let (ptr, set_changed) = crate::utils::mut_untyped_split(value);
assert_eq!(reflect_from_ptr.type_id(), type_id);
let value = unsafe { reflect_from_ptr.as_reflect_mut(ptr) };
Ok((value, set_changed))
}
// Tests showing that split views grant disjoint access, so borrows from different views can
// be held simultaneously.
#[cfg(test)]
mod tests {
    use std::any::TypeId;
    use bevy_ecs::prelude::*;
    use bevy_reflect::{Reflect, TypeRegistry};
    use super::RestrictedWorldView;

    // Resource that does not derive `Reflect`; only accessed through typed APIs.
    #[derive(Resource)]
    struct A(String);

    // Reflectable resource; used for the by-`TypeId` reflection path.
    #[derive(Resource, Reflect, Default)]
    #[reflect(Resource)]
    struct B(String);

    /// After splitting off `A`, `A` (from the split view) and `B` (from the remainder) can be
    /// mutably borrowed at the same time.
    #[test]
    fn disjoint_resource_access() {
        let mut world = World::new();
        world.insert_resource(A("a".to_string()));
        world.insert_resource(B("b".to_string()));
        let mut world = RestrictedWorldView::new(&mut world);
        let (mut a_view, mut world) = world.split_off_resource(TypeId::of::<A>());
        let mut a = a_view.get_resource_mut::<A>().unwrap();
        let mut b = world.get_resource_mut::<B>().unwrap();
        // Both borrows are used here, so they must be alive simultaneously.
        a.0.clear();
        b.0.clear();
    }

    /// Same as `disjoint_resource_access`, but `B` is obtained via reflection by `TypeId`.
    #[test]
    fn disjoint_resource_access_by_id() {
        let mut world = World::new();
        world.insert_resource(A("a".to_string()));
        world.insert_resource(B("b".to_string()));
        let mut world = RestrictedWorldView::new(&mut world);
        let (mut a_view, mut world) = world.split_off_resource(TypeId::of::<A>());
        let mut a = a_view.get_resource_mut::<A>().unwrap();
        // Only `B` is registered; reflection of `A` is never requested.
        let mut type_registry = TypeRegistry::empty();
        type_registry.register::<B>();
        let b = world
            .get_resource_reflect_mut_by_id(TypeId::of::<B>(), &type_registry)
            .unwrap();
        a.0.clear();
        // `b.0` is the `&mut dyn Reflect`; downcast back to the concrete type.
        b.0.downcast_mut::<B>().unwrap().0.clear();
    }

    /// `get_two_resources_mut` hands out two simultaneous borrows of distinct resources.
    #[test]
    fn get_two_resources_mut() {
        let mut world = World::new();
        world.insert_resource(A("a".to_string()));
        world.insert_resource(B("b".to_string()));
        let mut world = RestrictedWorldView::new(&mut world);
        let (a, b) = world.get_two_resources_mut::<A, B>();
        a.unwrap().0.clear();
        b.unwrap().0.clear();
    }

    /// Checks the permission bookkeeping of successive resource splits. No resources are
    /// inserted, so only `allows_access_to_resource` is exercised.
    #[test]
    fn invalid_resource_access() {
        let mut world = World::new();
        let mut world = RestrictedWorldView::new(&mut world);
        let (a_view, mut a_remaining) = world.split_off_resource(TypeId::of::<A>());
        assert!(a_view.allows_access_to_resource(TypeId::of::<A>()));
        assert!(!a_remaining.allows_access_to_resource(TypeId::of::<A>()));
        assert!(!a_view.allows_access_to_resource(TypeId::of::<B>()));
        assert!(a_remaining.allows_access_to_resource(TypeId::of::<B>()));
        // Splitting the remainder again must also remove `B` from the new remainder.
        let (b_view, b_remaining) = a_remaining.split_off_resource(TypeId::of::<B>());
        assert!(b_view.allows_access_to_resource(TypeId::of::<B>()));
        assert!(!b_remaining.allows_access_to_resource(TypeId::of::<B>()));
    }

    #[derive(Component, Reflect)]
    struct ComponentA(String);

    /// A reflected component borrowed from a split view coexists with a resource borrowed
    /// from the remainder.
    #[test]
    fn disjoint_component_access() {
        let mut type_registry = TypeRegistry::empty();
        type_registry.register::<ComponentA>();
        type_registry.register::<String>();
        let mut world = World::new();
        world.insert_resource(A("a".to_string()));
        let entity = world.spawn(ComponentA("a".to_string())).id();
        let mut world = RestrictedWorldView::new(&mut world);
        let (mut component_view, mut world) =
            world.split_off_component((entity, TypeId::of::<ComponentA>()));
        let component = component_view
            .get_entity_component_reflect(entity, TypeId::of::<ComponentA>(), &type_registry)
            .unwrap();
        let mut resource = world.get_resource_mut::<A>().unwrap();
        // `component.0` is the `&mut dyn Reflect` returned by the accessor.
        component.0.downcast_mut::<ComponentA>().unwrap().0.clear();
        resource.0.clear();
    }
}